# PyTorch Transfer Learning
Transfer learning reuses a model pre-trained on a large-scale dataset for a new task that has much less training data.
It is one of the most widely used techniques in deep learning practice today: compared with training from scratch, transfer learning usually reaches better accuracy, trains faster, and needs less data.
* * *
## 1. Core Ideas of Transfer Learning
Features learned by deep neural networks on ImageNet are generalizable:
* Shallow layers: Learn general low-level features (edges, textures, color gradients)
* Middle layers: Learn general mid-level features (shapes, parts, texture combinations)
* Deep layers: Learn task-specific high-level features (faces, wheels, text)
These low- and mid-level features are useful for most vision tasks, so they do not need to be relearned.
### When to Use Transfer Learning?
| Data Amount | Similarity to Source Task | Recommended Strategy |
| --- | --- | --- |
| Small (< 1000) | High | Replace only the final classification head, freeze all backbone |
| Large (> 10000) | Any | Fine-tune all layers, or consider training from scratch |
### Comparison of Three Core Strategies
Schematically, the layers of a pre-trained model (e.g., ResNet50) fall into these groups:
* Conv Layer 1~3 (Low-level features: edges/textures): Usually frozen
* Conv Layer 4~6 (Mid-level features: shapes/parts): Optional to freeze
* Conv Layer 7~N (High-level features: semantic information): Fine-tune
* Classifier Head: Replace and train
* * *
## 2. Loading Pre-trained Models
PyTorch provides many official pre-trained models through torchvision.models, and loading them is very simple.
## Example
```python
import torch
import torchvision.models as models
from torchvision.models import ResNet50_Weights

# Load a pre-trained model (weights are downloaded and cached on first use)
# torchvision >= 0.13 recommends the weights parameter
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Old syntax: still accepted, but emits a deprecation warning
# (pretrained=True maps to the IMAGENET1K_V1 weights)
model = models.resnet50(pretrained=True)

# No pre-trained weights (network structure only, randomly initialized)
model = models.resnet50(weights=None)
```
### Viewing Model Structure
## Example
```python
import torchvision.models as models

model = models.resnet50(weights="IMAGENET1K_V2")

# Print the full structure
print(model)

# View only the classification head
print(model.fc)
# Linear(in_features=2048, out_features=1000, bias=True)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
```
### Classification Head Names for Different Models
Different architectures store the classification head under different attribute names, so you must replace the matching layer during transfer:
| Model | Classification Head Attribute |
| --- | --- |
| ResNet / RegNet | model.fc |
| VGG / AlexNet | model.classifier |
| DenseNet | model.classifier |
| EfficientNet | model.classifier |
| MobileNetV2/V3 | model.classifier |
| ViT (Vision Transformer) | model.heads.head |
| ConvNeXt | model.classifier |
| Inception V3 | model.fc |
| Swin Transformer | model.head |
* * *
## 3. Three Transfer Strategies
### 3.1 Strategy One: Feature Extraction (Freeze All)
Freeze all parameters of the pre-trained model and only train the newly replaced classification head.
Suitable for scenarios with very little data (a few hundred images), or when the task is highly similar to the source task.
## Example
```python
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet18_Weights

NUM_CLASSES = 5  # Number of classes in the target task

# Step 1: Load the pre-trained model
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Step 2: Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Step 3: Replace the classification head (new layers have requires_grad=True by default)
in_features = model.fc.in_features  # 512
model.fc = nn.Linear(in_features, NUM_CLASSES)

# Verify: only the classification head is trainable
trainable = [(n, p.shape) for n, p in model.named_parameters() if p.requires_grad]
print(f"Number of trainable parameter tensors: {len(trainable)}")
for name, shape in trainable:
    print(f"  {name}: {shape}")
# Output:
# Number of trainable parameter tensors: 2
#   fc.weight: torch.Size([5, 512])
#   fc.bias: torch.Size([5])

# Step 4: Pass only the trainable parameters to the optimizer
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,
)
# Equivalent, more explicit form:
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
```
> Feature extraction is the simplest and most commonly used transfer learning strategy, especially suitable for scenarios with limited data.
### 3.2 Strategy Two: Fine-tuning
Unfreeze all or some of the pre-trained layers and train them, together with the new head, using a smaller learning rate.
Suitable when you have a moderate amount of data, or when the target task differs somewhat from the source task.
## Example
```python
import torch
import torch.nn as nn
import torchvision.models as models

NUM_CLASSES = 10

# Method A: Full fine-tuning (all layers trainable)
model = models.resnet50(weights="IMAGENET1K_V2")
# Freshly loaded parameters already have requires_grad=True; the freeze-then-unfreeze
# pattern below is shown because it is the building block of gradual unfreezing
for param in model.parameters():
    param.requires_grad = False
for param in model.parameters():
    param.requires_grad = True

# Replace the classification head
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# A single small learning rate for all parameters; giving the backbone a smaller
# learning rate than the head is covered in Strategy Three
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

# Method B: Unfreeze only the last layers (partial fine-tuning)
model = models.resnet50(weights="IMAGENET1K_V2")

# First freeze everything
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only layer4 (the last residual stage of ResNet)
for param in model.layer4.parameters():
    param.requires_grad = True

# The new classification head is trainable by default
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

print("Trainable parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"  {name}")
```
### 3.3 Strategy Three