크고 복잡한 모델을 이용해 작고 가벼운 모델을 효과적으로 학습할 수 있는 기법인 Knowledge Distillation에 대한 설명
Knowledge Distillation 개요
-
Knowledge Distillation(지식 증류)란 미리 잘 학습된 큰 네트워크(Teacher network) 의 지식을 실제로 사용하고자 하는 작은 네트워크(Student network) 에게 전달하여 더 좋은 성능을 내도록 학습하는 것.
-
Knowledge Distillation은 단순히 라벨(hard target, [0, 1, 0, 0])을 학습하는 것이 아닌 Teacher network의 더 풍부한 정보를 Student network에 전달하여 더 일반화된 학습을 할 수 있도록 한다.
-
따라서, 작은 네트워크도 큰 네트워크와 비슷한 성능을 낼 수 있도록 하여 computing resource 측면에서 더 효율적인 시스템을 구성할 수 있게 해준다.
Knowledge Distillation 방법
학습 구조
- Teacher model은 훈련 데이터셋을 통해 사전학습된 크고 복잡한 모델이고, Student model은 학습(지식 증류) 대상인 작고 가벼운 모델을 말한다.
- 기본 개념은 Techer model의 결과와 Student model의 출력를 이용해 Loss(KL Divergence Loss)를 구하고 역전파를 통해 Student model이 학습되는 것이며, 이 과정에서 label(정답) 데이터가 사용되지 않으므로 비지도 학습이라 할 수 있다.
- 하지만, 일반적으로 Techer model을 학습시키기 위한 label 데이터가 존재하며 label 정보를 이용해 아래와 같은 방법으로 Knowledge Distillation 학습이 이루어 진다.
- 두 가지 손실함수(loss)를 통해 Student model이 학습되는데, Teacher model과 Student model의 soft label의 분포 차이를 비교하는 distillation loss와 Student model의 hard label 출력과 정답을 비교하기 위한 student loss가 사용 된다.
- Distillation Loss
- 보통 KL Divergence 사용
- logits를 softmax 후 비교하며, temperature(T) 를 도입해 더 “부드러운” 분포를 만듦
- Teacher model의 soft prediction과 Student model의 soft prediction 분포의 차이가 Loss
- Teacher model의 지식을 모방하기 위한 Loss
- Student Loss
- CrossEntropy 사용
- Student model의 True label에 대한 Confidence의 음의 값이 Loss
- True Label을 예측하기 위한 Loss
- Total Loss
- 두 loss를 적절히 활용한 최종 Loss
- Weighted sum of two Losses
Soft labels
- 먼저, Hard label은 one-hot encoding 형태로 표현되며 정확히 어떤 클래스가 정답인지만 고려하고 클래스 간 유사도 정보는 포함하지 않는다. 반면 Soft label은 0~1 사이의 클래스간 상대적 확률 분포로 표현하며 주로 모델의 최종 Softmax 함수 를 통해 출력된다.
hard_label = [0, 1, 0, 0, 0]
soft_label = [0.1, 0.7, 0.1, 0.05, 0.05]
- 일반적인 Softmax 함수는 아래와 같이 정의 된다.
- 위 Softmax 함수의 결과와 같이 logit(Softmax의 입력) 간의 차이가 크면 매우 sharp한 결과를 보이며, 이는 hard label을 이용해 학습하는 것과 유사하게 일반화 보다는 정답만을 쫓게 된다.
- 따라서 Distillation Loss를 계산할 때, Softmax 함수에 Temperature 상수를 추가하여 분포의 sharpness를 조절함으로써 student가 단순한 정답 암기보다 더 넓은 관점으로 학습하도록 돕도록 한다.
- 부드러운(soft) 결과가 Student가 더 많은 정보(유사 클래스 간 관계 등)를 학습하도록 유도함.
Total Loss (Distillation Loss)
- Knowledge Distillation에서는 Student Loss와 Distillation Loss의 가중합하여 student를 학습한다.
- 여기서,
- : Teachers logits
- : Student logits
- : Temperature
- : Cross-Entropy with hard label(GT)
- : KL divergence (soft target distillation loss)
- : Student Loss(hard label) 비중
- 위 식에서 항이 들어간 이유는 softmax에 temperature 를 적용하면 미분값 gradient가 만큼 작아 진다.
- 따라서, Loss 값 자체에 을 곱하여 gradient scale의 균형을 맞추어 준다. (Gradient가 너무 작아져서 학습이 안되는 것을 방지함)
Knowledge Distillation 활용
- Knowledge Distillation은 특히 성능과 경량화를 동시에 요구하는 영역, 멀티모달 학습, 다중 모델 협력 등에 매우 효과적이다.
- 비슷한 아키텍처에서 더 작은 모델로의 Knowledge Distillation 응용 방법으로,
Resnet50
→resnet18
,BERT
→DistilBERT
등의 예가 있다. - 멀티모달 학습에서는 텍스트, 이미지, 오디오 등 여러 모달리티를 가진 모델에서, 하나의 모달리티를 teacher로 삼아 다른 모달리티를 가이드하는 방법이 활용된다.
- 모델 양자화에서 응용 방법으로는 양자화 캘리브레이션을 통해 PTQ 이후 미세 조정(Fine-tuning) 단계에서 양자화된 모델(int8)을 student로, 기존 모델(float32)을 teacher로 하여 distillation 학습을 활용하여 PTQ로 인한 정확도 손실을 복구할 수 있다. (PTQ + Distillation)
Distillation Loss Python code
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.3):
# Student loss: CrossEntropy
hard_loss = F.cross_entropy(student_logits, labels)
# Distillation loss: KLDiv with softmax + temperature
soft_teacher = F.softmax(teacher_logits / T, dim=1)
soft_student = F.log_softmax(student_logits / T, dim=1)
soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)
# Total combined loss
return alpha * hard_loss + (1 - alpha) * soft_loss
참고
- Knowledge Distillation - Neural Network Distiller
- Knowledge Distillation. Knowledge distillation is model… | by Ujjwal Upadhyay | Neural Machine | Medium
- Knowledge Distillation
- Knowledge Distillation Tutorial — PyTorch Tutorials 2.7.0+cu126 documentation
- 딥러닝 용어 정리, Knowledge distillation 설명과 이해
- 딥러닝 모델 지식의 증류기법, Knowledge Distillation | Seongsu