@ai2ys
Last active January 31, 2025 16:18
PyTorch Transfer Learning and Partially Freezing a Model

Transfer Learning with PyTorch

Trying out transfer learning with PyTorch. See the samples below; they might still contain some errors.
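Both samples build on the same two PyTorch mechanisms. A minimal sketch of the pattern (my own illustration, not taken from the samples; the 10-class head is an arbitrary assumption):

import torch.nn as nn
import torchvision.models as models

backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Freeze the pretrained weights: parameters with requires_grad=False receive no gradient updates.
for param in backbone.parameters():
    param.requires_grad = False
# Replace the classification head with a new, trainable layer for the target task.
backbone.fc = nn.Linear(backbone.fc.in_features, 10)

On top of this, the classes below additionally keep the frozen modules in eval() mode so that layers such as BatchNorm and Dropout do not change their behavior during training.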

# old approach used for VGG
import logging

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights


class Vgg16PartiallyFrozenFeatures(nn.Module):
    """VGG16 in which the first `num_frozen_blocks` convolutional blocks are frozen and
    loaded with ImageNet weights, while the remaining layers stay trainable and
    randomly initialized."""

    def __init__(self, num_frozen_blocks=3):
        super(Vgg16PartiallyFrozenFeatures, self).__init__()
        self.model = models.vgg16(weights=None)
        vgg16_pretrained = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        vgg16_features_state_dict = vgg16_pretrained.features.state_dict()
        # self.features_frozen = nn.Sequential()
        # self.features_trainable = nn.Sequential()
        # block_count = 0
        # for i, child in enumerate(self.model.features.children()):
        #     # print(i, child)
        #     if isinstance(child, nn.MaxPool2d):
        #         block_count += 1
        #     if block_count < num_frozen_blocks:
        #         self.features_frozen.append(child)
        #     else:
        #         self.features_trainable.append(child)
        # vgg16_features_state_dict = vgg16_pretrained.features.state_dict()
        # incompatible_keys = self.features_frozen.load_state_dict(vgg16_features_state_dict, strict=False)
        # logging.info("Missing keys:", incompatible_keys.missing_keys)
        # logging.info("Unexpected keys:", incompatible_keys.unexpected_keys)
        # self.features_frozen.requires_grad_(False)
        # self.model.features = nn.Sequential(self.features_frozen, self.features_trainable)
        self.frozen_childs = nn.Sequential()
        block_count = 0
        for i, child in enumerate(self.model.features.children()):
            if isinstance(child, nn.MaxPool2d):
                block_count += 1
            if block_count < num_frozen_blocks:
                # The modules are shared with self.model.features, so loading weights into
                # this container below also updates the corresponding layers of the full model.
                self.frozen_childs.append(child)
                logging.debug(f"freezing child: {child._get_name()}")
        vgg16_features_state_dict = vgg16_pretrained.features.state_dict()
        incompatible_keys = self.frozen_childs.load_state_dict(vgg16_features_state_dict, strict=False)
        self.frozen_childs.requires_grad_(False)
        logging.debug(f"Missing keys: {incompatible_keys.missing_keys}")
        logging.debug(f"Unexpected keys: {incompatible_keys.unexpected_keys}")
        if logging.root.level == logging.DEBUG:
            logging.debug("Comparing pretrained to custom weights - ")
            for (kr, vr), (kp, vp) in zip(self.model.state_dict().items(), vgg16_pretrained.state_dict().items()):
                logging.debug(f"equal: {kr}, {torch.allclose(vr, vp)}")

    # def train(self, mode: bool = True):
    #     self.model.train(mode)
    #     self.features_frozen.eval()
    #     print("requires grad...")
    #     for i, (name, param) in enumerate(self.model.named_parameters()):
    #         print(i, name, param.requires_grad)
    #     print("training...")
    #     for i, (name, module) in enumerate(self.model.named_modules()):
    #         print(i, name, module.training)
    #     # self.features_frozen.requires_grad_(False)

    def train(self, mode: bool = True):
        self.model.train(mode)
        # Keep the frozen blocks in eval mode so mode-dependent layers (e.g. BatchNorm, Dropout) stay fixed there.
        self.frozen_childs.eval()
        if logging.root.level == logging.DEBUG:
            logging.debug("requires grad...")
            for i, (name, param) in enumerate(self.model.named_parameters()):
                logging.debug(f"{i}, {name}, {param.requires_grad}")
            logging.debug("training...")
            for i, (name, module) in enumerate(self.model.named_modules()):
                logging.debug(f"{i}, {name}, {module.training}")

    def forward(self, x):
        return self.model(x)


# %%
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(message)s')
test = Vgg16PartiallyFrozenFeatures()
# %%
test.train()
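# %%
# Minimal optimizer sketch (my assumption, not part of the original gist): hand only the
# still-trainable parameters to the optimizer; the frozen VGG blocks have requires_grad=False
# and are skipped automatically.
trainable_params = [p for p in test.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
# Dummy forward/backward pass to check that only the trainable layers receive gradients.
out = test(torch.randn(1, 3, 224, 224))
out.sum().backward()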
# %%
import torch
import torch.nn as nn
import torchvision.models as models
from typing import Callable, Optional, List


class TransferLearningResNet(nn.Module):
    def __init__(
        self,
        num_classes: int,
        model_fn: Callable[..., models.ResNet] = models.resnet18,
        model_weights: Optional[models.WeightsEnum] = models.ResNet18_Weights.IMAGENET1K_V1,
        freeze_backbone: bool = False,
        inference_mode_backbone: bool = False,
    ):
        super(TransferLearningResNet, self).__init__()
        self.freeze_backbone = freeze_backbone  # Store freeze status
        self.inference_mode_backbone = inference_mode_backbone  # Store inference mode flag
        # Ensure that freezing is only allowed when weights are provided
        if model_weights is None and freeze_backbone:
            raise ValueError("Cannot freeze the backbone when using randomly initialized weights (weights=None).")
        # Load the ResNet model with the specified weights (or None for random init)
        self.backbone = model_fn(weights=model_weights)
        # Freeze backbone only if pre-trained weights are used
        if freeze_backbone:
            self._set_requires_grad(self.backbone, requires_grad=False)
        # Remove the final classification layer
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # Replace the last FC layer with an identity layer
        # Custom classifier
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        # Handle backbone behavior
        if self.freeze_backbone:
            with torch.no_grad():  # Avoids computing gradients for the backbone
                if self.inference_mode_backbone:
                    self.backbone.eval()  # Stops BatchNorm & Dropout updates
                representations = self.backbone(x)
        else:
            representations = self.backbone(x)
        x = self.classifier(representations)  # Classify
        return x  # Logits output

    def _set_requires_grad(self, module: nn.Module, requires_grad: bool):
        """Helper function to set requires_grad for a module."""
        for param in module.parameters():
            param.requires_grad = requires_grad

    def set_trainable_layers(self, layer_names: Optional[List[str]] = None):
        """
        Unfreezes specific layers in the backbone and ensures BatchNorm is back in training mode.

        Example: unfreeze the last residual block (e.g., "layer4" in ResNet18) with
        `model.set_trainable_layers(["layer4"])`.

        Args:
            layer_names (List[str], optional): List of layer names to unfreeze.
                If None, all layers are unfrozen.
        """
        if layer_names is None:
            # Unfreeze the entire backbone
            self._set_requires_grad(self.backbone, requires_grad=True)
            self.freeze_backbone = False
            self.backbone.train()  # Put BatchNorm back into training mode
        else:
            # Unfreeze specific layers
            for name, module in self.backbone.named_modules():
                if any(layer_name in name for layer_name in layer_names):
                    self._set_requires_grad(module, requires_grad=True)
            # Ensure BatchNorm layers in unfrozen blocks return to train mode
            for name, module in self.backbone.named_modules():
                if isinstance(module, nn.BatchNorm2d) and any(layer_name in name for layer_name in layer_names):
                    module.train()  # Allow BatchNorm to update again
            # Without this, forward() would still wrap the backbone in torch.no_grad() and the
            # newly unfrozen layers could not receive gradients.
            self.freeze_backbone = False
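# %%
# Minimal usage sketch (my assumption of the intended workflow, not part of the original gist):
# first train only the new classifier head on top of the frozen backbone, then unfreeze the
# last residual block as described in the set_trainable_layers docstring.
model = TransferLearningResNet(num_classes=10, freeze_backbone=True, inference_mode_backbone=True)
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=1e-3, momentum=0.9)
logits = model(torch.randn(2, 3, 224, 224))  # -> shape (2, 10)
# Later: unfreeze "layer4" and rebuild the optimizer so it also covers the newly trainable parameters.
model.set_trainable_layers(["layer4"])
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=1e-4, momentum=0.9)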