We write code that loads a pre-trained ResNet model and fine-tunes it on a custom dataset, and look at several PyTorch functions related to training along the way.

Execution Environment

  • CPU: Intel(R) Xeon(R) w5-3423 @ 12 cores
  • RAM: 256 GB
  • GPU: NVIDIA RTX A6000 48 GB x 2
  • OS: Ubuntu 22.04 LTS
  • Python: 3.11.9
  • pytorch: 2.5.0+cu118
  • torchvision: 0.20.0+cu118

Python Code

1. Load Pre-Trained Model

  • Load a ResNet18 model pre-trained on the ImageNet-1K dataset.
  • The ResNet18 model (architecture) is provided by torchvision.
  • ResNet reference: Classification Models: ResNet
import torch
import torchvision.models as models
 
# Load the pre-trained ResNet-18 model  
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

2. Modify the last layer

  • Modify the last FC (Fully Connected) layer so that its output matches the number of classes in your dataset.
# Modify the last layer of the model  
num_classes = 10 # replace with the number of classes in your dataset  
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
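
As a quick sanity check (a hypothetical snippet; the 224x224 input and batch size of 1 are arbitrary), the replaced head should now produce num_classes logits:
model.eval()  # avoid batch-norm issues with a single sample
dummy = torch.randn(1, 3, 224, 224)
print(model(dummy).shape)  # expected: torch.Size([1, 10])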

3. Load the custom dataset and pre-process

  • Load the custom dataset using PyTorch's DataLoader module. The dataset folder structure is shown below.
custom_dataset/  
├── train/  
│ ├── class1/  
│ ├── class2/  
│ ├── ...  
├── val/  
│ ├── class1/  
│ ├── class2/  
│ ├── ...  
├── test/  
│ ├── class1/  
│ ├── class2/  
│ ├── ...
  • Use the code below to load the train, val, and test data.
  • The transform defines the data pre-processing; the pre-processing steps should be customized for your classification task and dataset.
  • For the ResNet models provided by torchvision, an AdaptiveAvgPool2d() layer is applied before the FC layer, so the input size is not constrained. The AdaptiveAvgPool2d() layer guarantees that the feature map passed to the FC layer always has a fixed size (see the short sketch after the loader code below).
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
 
# Define the transformations to apply to the images
transform = transforms.Compose([
    transforms.Resize(256),      # resize so that the shorter side is 256, preserving the aspect ratio
    transforms.CenterCrop(224),  # omit this if your task needs to see the entire image
    # A Lambda transform can crop an arbitrary region
    # (requires `import torchvision.transforms.functional as F` and pre-defined h, w):
    # transforms.Lambda(lambda img: F.crop(img, top=h//2, left=0, height=h//2, width=w)),
    transforms.ToTensor(),       # ToTensor() also rescales pixel values from 0-255 to 0-1
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
])
 
# Load the train and validation datasets  
train_dataset = ImageFolder('/Users/pytorch_zero_to_hero/animal_dataset/train', transform=transform)  
val_dataset = ImageFolder('/Users/pytorch_zero_to_hero/animal_dataset/val', transform=transform)  
test_dataset = ImageFolder('/Users/pytorch_zero_to_hero/animal_dataset/test', transform=transform)  
  
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)  
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
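
A minimal sketch of the note above: AdaptiveAvgPool2d() produces an output of the requested size regardless of the input feature-map size (512 channels matches ResNet18's last stage; the spatial sizes here are arbitrary).
pool = torch.nn.AdaptiveAvgPool2d((1, 1))  # the same kind of layer torchvision's ResNet applies before the FC
print(pool(torch.randn(1, 512, 7, 7)).shape)    # torch.Size([1, 512, 1, 1])
print(pool(torch.randn(1, 512, 10, 15)).shape)  # torch.Size([1, 512, 1, 1])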

4. Define the loss function and optimizer

# With torch.nn.CrossEntropyLoss as the loss function, there is no need to add a softmax layer to the network
criterion = torch.nn.CrossEntropyLoss()
 
## Train all layers
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
## To train only the last layer
# optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
## To train only specific layers
# for param in model.parameters():
#     param.requires_grad = False
# for param in model.layer4.parameters():
#     param.requires_grad = True
# for param in model.fc.parameters():
#     param.requires_grad = True
# optimizer = torch.optim.SGD([
#     {'params': model.layer4.parameters(), 'lr': 0.0005},
#     {'params': model.fc.parameters(), 'lr': 0.01},
# ], momentum=0.9)
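
When freezing layers as above, it helps to verify which parameters will actually be updated; a minimal sketch (the variable name is illustrative):
# List the parameter tensors that still require gradients
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable parameter tensors, e.g. {trainable[:2]}")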

5. Define a function to Train the model

  • Training runs for num_epochs epochs; after each epoch, validation loss and accuracy are computed, the best model (by validation accuracy) is saved, and a checkpoint can optionally be saved every save_epoch epochs.
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs,
          save_path='./', save_epoch=0):
    # Determine whether to use GPU (if available) or CPU
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

    # Track the best validation accuracy for checkpointing
    best_val_acc = 0.0
 
    for epoch in range(num_epochs):
        # Set the model to training mode
        model.train()
 
        # Initialize running loss and correct predictions count for training
        running_loss = 0.0
        running_corrects = 0
 
        # Iterate over the training data loader
        for inputs, labels in train_loader:
            # Move inputs and labels to the device (GPU or CPU)
            inputs = inputs.to(device)
            labels = labels.to(device)
 
            # Reset the gradients to zero before the backward pass
            optimizer.zero_grad()
 
            # Forward pass: compute the model output
            outputs = model(inputs)
            # Get the predicted class (with the highest score)
            _, preds = torch.max(outputs, 1)
            # Compute the loss between the predictions and actual labels
            loss = criterion(outputs, labels)
 
            # Backward pass: compute gradients
            loss.backward()
            # Perform the optimization step to update model parameters
            optimizer.step()
 
            # Accumulate the running loss and the number of correct predictions
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
 
        # Compute average training loss and accuracy for this epoch
        train_loss = running_loss / len(train_loader.dataset)
        train_acc = running_corrects.float() / len(train_loader.dataset)
 
        # Set the model to evaluation mode for validation
        model.eval()
        # Initialize running loss and correct predictions count for validation
        running_loss = 0.0
        running_corrects = 0
 
        # Disable gradient computation for validation (saves memory and computations)
        with torch.no_grad():
            # Iterate over the validation data loader
            for inputs, labels in val_loader:
                # Move inputs and labels to the device (GPU or CPU)
                inputs = inputs.to(device)
                labels = labels.to(device)
 
                # Forward pass: compute the model output
                outputs = model(inputs)
                # Get the predicted class (with the highest score)
                _, preds = torch.max(outputs, 1)
                # Compute the loss between the predictions and actual labels
                loss = criterion(outputs, labels)
 
                # Accumulate the running loss and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
 
        # Compute average validation loss and accuracy for this epoch
        val_loss = running_loss / len(val_loader.dataset)
        val_acc = running_corrects.float() / len(val_loader.dataset)
 
        # Print the results for the current epoch
        print(f'Epoch [{epoch+1}/{num_epochs}], train loss: {train_loss:.4f}, train acc: {train_acc:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.4f}')
	    
        # Save the model if this is the best so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), f'{save_path}best.pth')
            print(f"Saved best model with val acc: {best_val_acc:.4f}")
 
        # Save every N epochs
        # Save a checkpoint every save_epoch epochs (disabled when save_epoch == 0)
        if save_epoch != 0 and (epoch + 1) % save_epoch == 0:
            torch.save(model.state_dict(), f'{save_path}epoch{epoch+1}.pth')
            print(f"Saved model at epoch {epoch+1}")
 
        # Stop early once the validation loss is essentially zero
        if val_loss < 0.001:
            break

6. Fine-tune the model on the custom dataset

  • Define the device and run training.
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=5)

7. Evaluate the Model

  • Define the evaluate function and run the evaluation (a usage example follows the function).
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
 
def evaluate(model, test_loader, device, save_file_name):
    # Initialize dictionaries to store correct and total predictions
    correct_pred = {classname: 0 for classname in test_loader.dataset.classes}
    total_pred = {classname: 0 for classname in test_loader.dataset.classes}
    
    # Per-class sum of softmax scores over correct predictions
    correct_score_pred = {classname: 0 for classname in test_loader.dataset.classes}
 
    # Set the model to evaluation mode
    model.eval()
 
    # Track the ground truth labels and predictions
    all_labels = []
    all_preds = []
    all_scores = []
 
    with torch.no_grad():
        for inputs, labels in test_loader:
            # Move the inputs and labels to the device
            inputs = inputs.to(device)
            labels = labels.to(device)
 
            # Forward pass
            outputs = model(inputs)  # shape: [batch_size, num_classes]

            scores, preds = torch.max(outputs, 1)  # preds is equivalent to argmax over classes
 
            # Convert logits to class probabilities and take the top probability
            softmax = torch.nn.Softmax(dim=1)
            soft_scores = softmax(outputs)
            s_scores, _ = torch.max(soft_scores, 1)
            
            # Collect predictions and labels for metric calculations
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            #all_scores.extend(s_scores.cpu().numpy())
 
            # Update the correct and total predictions (and accumulate scores of correct ones)
 
            for label, prediction, score in zip(labels, preds, s_scores):
                classname = test_loader.dataset.classes[label]
                if label == prediction:
                    correct_pred[classname] += 1
                    correct_score_pred[classname] += score.item()  # accumulate the softmax score as a Python float
                total_pred[classname] += 1
 
    # Calculate average score per class
    avg_score_per_class = {classname: correct_score_pred[classname] / correct_pred[classname] if correct_pred[classname] > 0 else 0
                          for classname in test_loader.dataset.classes}
 
    # Calculate accuracy per class
    accuracy_per_class = {classname: correct_pred[classname] / total_pred[classname] if total_pred[classname] > 0 else 0
                          for classname in test_loader.dataset.classes}
 
    # Calculate overall accuracy
    overall_accuracy = accuracy_score(all_labels, all_preds)
 
    # Print the evaluation results
    print("Accuracy per class:")
    for classname, accuracy in accuracy_per_class.items():
        print(f"{classname}: {accuracy:.4f}")
    
    print("Average score per class:")
    for classname, avg_score in avg_score_per_class.items():
        print(f"{classname}: {avg_score:.4f}")
 
    print()
    print(f"Overall Accuracy: {overall_accuracy:.4f}")
 
    print()
    report = classification_report(all_labels, all_preds, target_names=test_loader.dataset.classes)
    print(report)
 
    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=test_loader.dataset.classes,
                yticklabels=test_loader.dataset.classes, cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig(save_file_name)  # save before show(); show() may clear the figure
    plt.show()
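
A hypothetical usage example, assuming the best checkpoint saved by train() with the default save_path='./':
# Load the best checkpoint and evaluate on the test set
model.load_state_dict(torch.load('best.pth', map_location=device))
evaluate(model, test_loader, device, save_file_name='confusion_matrix.png')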

ONNX Conversion

Considerations when converting to ONNX

  • The ResNet18 model provided by PyTorch includes an AdaptiveAvgPool2d() layer to support dynamic input sizes; this layer is exported to ONNX as a layer named /avgpool/GlobalAveragePool.
  • In particular, when exporting to ONNX for embedded AI compilers such as Hailo, dynamic inputs and the corresponding GlobalAveragePool layer are often poorly supported.
  • Therefore, when exporting to ONNX, it is advisable to fix the input size and replace the AdaptiveAvgPool2d() layer with an AvgPool2d() layer matched to the fixed input before exporting, as sketched below.
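
A minimal sketch of this replacement and export, assuming the fine-tuned `model` from above and a fixed 224x224 input (for which ResNet18's last feature map is 7x7); the output file name and opset version are arbitrary choices:
import torch
 
# Replace adaptive pooling with a fixed-size average pool
# (layer4 outputs a 7x7 feature map for a 224x224 input)
model.avgpool = torch.nn.AvgPool2d(kernel_size=7, stride=1)
model.eval()
 
# Export with a fixed input shape (no dynamic axes)
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model.cpu(),
    dummy_input,
    "resnet18_finetuned.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=13,
)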

References