크고 복잡한 모델을 이용해 작고 가벼운 모델을 효과적으로 학습할 수 있는 기법인 Knowledge Distillation에 대한 설명

Knowledge Distillation 개요

  • Knowledge Distillation(지식 증류)란 미리 잘 학습된 큰 네트워크(Teacher network) 의 지식을 실제로 사용하고자 하는 작은 네트워크(Student network) 에게 전달하여 더 좋은 성능을 내도록 학습하는 것.

  • Knowledge Distillation은 단순히 라벨(hard target, [0, 1, 0, 0])을 학습하는 것이 아닌 Teacher network의 더 풍부한 정보를 Student network에 전달하여 더 일반화된 학습을 할 수 있도록 한다.

  • 따라서, 작은 네트워크도 큰 네트워크와 비슷한 성능을 낼 수 있도록 하여 computing resource 측면에서 더 효율적인 시스템을 구성할 수 있게 해준다.

  • Paper: Distilling the Knowledge in a Neural Network, 2014


Knowledge Distillation 방법

학습 구조

  • Teacher model은 훈련 데이터셋을 통해 사전학습된 크고 복잡한 모델이고, Student model은 학습(지식 증류) 대상인 작고 가벼운 모델을 말한다.
  • 기본 개념은 Techer model의 결과와 Student model의 출력를 이용해 Loss(KL Divergence Loss)를 구하고 역전파를 통해 Student model이 학습되는 것이며, 이 과정에서 label(정답) 데이터가 사용되지 않으므로 비지도 학습이라 할 수 있다.

  • 하지만, 일반적으로 Techer model을 학습시키기 위한 label 데이터가 존재하며 label 정보를 이용해 아래와 같은 방법으로 Knowledge Distillation 학습이 이루어 진다.

+full

  • 두 가지 손실함수(loss)를 통해 Student model이 학습되는데, Teacher model과 Student model의 soft label의 분포 차이를 비교하는 distillation loss와 Student model의 hard label 출력과 정답을 비교하기 위한 student loss가 사용 된다.
  • Distillation Loss
    • 보통 KL Divergence 사용
    • logits를 softmax 후 비교하며, temperature(T) 를 도입해 더 “부드러운” 분포를 만듦
    • Teacher model의 soft prediction과 Student model의 soft prediction 분포의 차이가 Loss
    • Teacher model의 지식을 모방하기 위한 Loss
  • Student Loss
    • CrossEntropy 사용
    • Student model의 True label에 대한 Confidence의 음의 값이 Loss
    • True Label을 예측하기 위한 Loss
  • Total Loss
    • 두 loss를 적절히 활용한 최종 Loss
    • Weighted sum of two Losses

Soft labels

  • 먼저, Hard label은 one-hot encoding 형태로 표현되며 정확히 어떤 클래스가 정답인지만 고려하고 클래스 간 유사도 정보는 포함하지 않는다. 반면 Soft label은 0~1 사이의 클래스간 상대적 확률 분포로 표현하며 주로 모델의 최종 Softmax 함수 를 통해 출력된다.
hard_label = [0, 1, 0, 0, 0]
soft_label = [0.1, 0.7, 0.1, 0.05, 0.05]
  • 일반적인 Softmax 함수는 아래와 같이 정의 된다.
  • 위 Softmax 함수의 결과와 같이 logit(Softmax의 입력) 간의 차이가 크면 매우 sharp한 결과를 보이며, 이는 hard label을 이용해 학습하는 것과 유사하게 일반화 보다는 정답만을 쫓게 된다.
  • 따라서 Distillation Loss를 계산할 때, Softmax 함수에 Temperature 상수를 추가하여 분포의 sharpness를 조절함으로써 student가 단순한 정답 암기보다 더 넓은 관점으로 학습하도록 돕도록 한다.
  • 부드러운(soft) 결과가 Student가 더 많은 정보(유사 클래스 간 관계 등)를 학습하도록 유도함.

Total Loss (Distillation Loss)

  • Knowledge Distillation에서는 Student LossDistillation Loss가중합하여 student를 학습한다.
  • 여기서,
    • : Teachers logits
    • : Student logits
    • : Temperature
    • : Cross-Entropy with hard label(GT)
    • : KL divergence (soft target distillation loss)
    • : Student Loss(hard label) 비중
  • 위 식에서 항이 들어간 이유는 softmax에 temperature 를 적용하면 미분값 gradient가 만큼 작아 진다.
  • 따라서, Loss 값 자체에 을 곱하여 gradient scale의 균형을 맞추어 준다. (Gradient가 너무 작아져서 학습이 안되는 것을 방지함)

Knowledge Distillation 활용

  • Knowledge Distillation은 특히 성능과 경량화를 동시에 요구하는 영역, 멀티모달 학습, 다중 모델 협력 등에 매우 효과적이다.
  • 비슷한 아키텍처에서 더 작은 모델로의 Knowledge Distillation 응용 방법으로, Resnet50resnet18 , BERTDistilBERT 등의 예가 있다.
  • 멀티모달 학습에서는 텍스트, 이미지, 오디오 등 여러 모달리티를 가진 모델에서, 하나의 모달리티를 teacher로 삼아 다른 모달리티를 가이드하는 방법이 활용된다.
  • 모델 양자화에서 응용 방법으로는 양자화 캘리브레이션을 통해 PTQ 이후 미세 조정(Fine-tuning) 단계에서 양자화된 모델(int8)을 student로, 기존 모델(float32)을 teacher로 하여 distillation 학습을 활용하여 PTQ로 인한 정확도 손실을 복구할 수 있다. (PTQ + Distillation)

Distillation Loss Python code

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.3):
    # Student loss: CrossEntropy
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # Distillation loss: KLDiv with softmax + temperature
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    soft_student = F.log_softmax(student_logits / T, dim=1)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)
    
    # Total combined loss
    return alpha * hard_loss + (1 - alpha) * soft_loss

참고