Trying out transfer learning with PyTorch. See the samples below; they might still contain some errors.
Last active: January 31, 2025 16:18
PyTorch Transfer Learning and Partially Freezing a Model
# old approach used for VGG
import logging

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights


class Vgg16PartiallyFrozenFeatures(nn.Module):
    def __init__(self, num_frozen_blocks=3):
        super(Vgg16PartiallyFrozenFeatures, self).__init__()
        self.model = models.vgg16(weights=None)
        vgg16_pretrained = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        vgg16_features_state_dict = vgg16_pretrained.features.state_dict()

        # self.features_frozen = nn.Sequential()
        # self.features_trainable = nn.Sequential()
        # block_count = 0
        # for i, child in enumerate(self.model.features.children()):
        #     # print(i, child)
        #     if isinstance(child, nn.MaxPool2d):
        #         block_count += 1
        #     if block_count < num_frozen_blocks:
        #         self.features_frozen.append(child)
        #     else:
        #         self.features_trainable.append(child)
        # vgg16_features_state_dict = vgg16_pretrained.features.state_dict()
        # incompatible_keys = self.features_frozen.load_state_dict(vgg16_features_state_dict, strict=False)
        # logging.info("Missing keys:", incompatible_keys.missing_keys)
        # logging.info("Unexpected keys:", incompatible_keys.unexpected_keys)
        # self.features_frozen.requires_grad_(False)
        # self.model.features = nn.Sequential(self.features_frozen, self.features_trainable)

        # Collect the first `num_frozen_blocks` convolutional blocks (a block
        # ends with a MaxPool2d) into a separate Sequential. The modules are
        # shared with self.model.features, so freezing them here also freezes
        # them inside the full model.
        self.frozen_childs = nn.Sequential()
        block_count = 0
        for i, child in enumerate(self.model.features.children()):
            if isinstance(child, nn.MaxPool2d):
                block_count += 1
            if block_count < num_frozen_blocks:
                self.frozen_childs.append(child)
                logging.debug(f"freezing child: {child._get_name()}")

        # Load the pretrained weights into the frozen prefix only; with
        # strict=False the keys of the remaining feature layers are ignored.
        incompatible_keys = self.frozen_childs.load_state_dict(vgg16_features_state_dict, strict=False)
        self.frozen_childs.requires_grad_(False)
        logging.debug(f"Missing keys: {incompatible_keys.missing_keys}")
        logging.debug(f"Unexpected keys: {incompatible_keys.unexpected_keys}")

        if logging.root.level == logging.DEBUG:
            logging.debug("Comparing pretrained to custom weights - ")
            for (kr, vr), (kp, vp) in zip(self.model.state_dict().items(), vgg16_pretrained.state_dict().items()):
                logging.debug(f"equal: {kr}, {torch.allclose(vr, vp)}")

    # def train(self, mode: bool = True):
    #     self.model.train(mode)
    #     self.features_frozen.eval()
    #     print("requires grad...")
    #     for i, (name, param) in enumerate(self.model.named_parameters()):
    #         print(i, name, param.requires_grad)
    #     print("training...")
    #     for i, (name, module) in enumerate(self.model.named_modules()):
    #         print(i, name, module.training)
    #     # self.features_frozen.requires_grad_(False)

    def train(self, mode: bool = True):
        # Keep the frozen blocks in eval mode even when the rest of the model
        # is switched to training mode.
        self.model.train(mode)
        self.frozen_childs.eval()
        if logging.root.level == logging.DEBUG:
            logging.debug("requires grad...")
            for i, (name, param) in enumerate(self.model.named_parameters()):
                logging.debug(f"{i}, {name}, {param.requires_grad}")
            logging.debug("training...")
            for i, (name, module) in enumerate(self.model.named_modules()):
                logging.debug(f"{i}, {name}, {module.training}")
        return self

    def forward(self, x):
        return self.model.forward(x)


# %%
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(message)s')
test = Vgg16PartiallyFrozenFeatures()
# %%
test.train()
# %%
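For context, a minimal usage sketch for the wrapper above. Only Vgg16PartiallyFrozenFeatures comes from the snippet; the dummy batch, learning rate, and loss are illustrative stand-ins, and the optimizer is handed only the parameters that still require gradients.

# Hypothetical usage sketch (not part of the original gist): train the
# partially frozen VGG16 on a dummy batch, optimizing only the parameters
# that still require gradients.
import torch
import torch.nn as nn

model = Vgg16PartiallyFrozenFeatures(num_frozen_blocks=3)
model.train()  # frozen blocks stay in eval mode, see train() above

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)

images = torch.randn(4, 3, 224, 224)   # stands in for a real DataLoader batch
labels = torch.randint(0, 1000, (4,))  # VGG16 keeps its 1000-class head here

optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()   # no gradients reach the frozen blocks
optimizer.step()

The second file in the gist takes a different route: instead of splitting the VGG feature blocks, it wraps a torchvision ResNet and controls freezing through constructor flags and a set_trainable_layers helper.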
import torch
import torch.nn as nn
import torchvision.models as models
from typing import Callable, Optional, List


class TransferLearningResNet(nn.Module):
    def __init__(
        self,
        num_classes: int,
        model_fn: Callable[..., models.ResNet] = models.resnet18,
        model_weights: Optional[models.WeightsEnum] = models.ResNet18_Weights.IMAGENET1K_V1,
        freeze_backbone: bool = False,
        inference_mode_backbone: bool = False,
    ):
        super(TransferLearningResNet, self).__init__()
        self.freeze_backbone = freeze_backbone  # Store freeze status
        self.inference_mode_backbone = inference_mode_backbone  # Store inference mode flag

        # Ensure that freezing is only allowed when weights are provided
        if model_weights is None and freeze_backbone:
            raise ValueError("Cannot freeze the backbone when using randomly initialized weights (weights=None).")

        # Load the ResNet model with the specified weights (or None for random init)
        self.backbone = model_fn(weights=model_weights)

        # Freeze backbone only if pre-trained weights are used
        if freeze_backbone:
            self._set_requires_grad(self.backbone, requires_grad=False)

        # Remove the final classification layer
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # Replace the last FC layer with an identity layer

        # Custom classifier
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        # Handle backbone behavior
        if self.freeze_backbone:
            with torch.no_grad():  # Avoids computing gradients for the backbone
                if self.inference_mode_backbone:
                    self.backbone.eval()  # Stops BatchNorm & Dropout updates
                representations = self.backbone(x)
        else:
            representations = self.backbone(x)
        x = self.classifier(representations)  # Classify
        return x  # Logits output

    def _set_requires_grad(self, module: nn.Module, requires_grad: bool):
        """Helper function to set requires_grad for a module."""
        for param in module.parameters():
            param.requires_grad = requires_grad

    def set_trainable_layers(self, layer_names: Optional[List[str]] = None):
        """
        Unfreezes specific layers in the backbone and ensures BatchNorm is back in training mode.

        Example: unfreeze the last residual block (e.g., "layer4" in ResNet18) with
        `model.set_trainable_layers(["layer4"])`.

        Args:
            layer_names (List[str], optional): List of layer names to unfreeze.
                If None, all layers are unfrozen.
        """
        if layer_names is None:
            # Unfreeze the entire backbone
            self._set_requires_grad(self.backbone, requires_grad=True)
            self.freeze_backbone = False
            self.backbone.train()  # Put BatchNorm back into training mode
        else:
            # Unfreeze specific layers
            for name, module in self.backbone.named_modules():
                if any(layer_name in name for layer_name in layer_names):
                    self._set_requires_grad(module, requires_grad=True)
            # Ensure BatchNorm layers in unfrozen blocks return to train mode
            for name, module in self.backbone.named_modules():
                if isinstance(module, nn.BatchNorm2d) and any(layer_name in name for layer_name in layer_names):
                    module.train()  # Allow BatchNorm to update again
            # Note: self.freeze_backbone stays True here, so forward() still wraps
            # the backbone in torch.no_grad(); set it to False if gradients should
            # flow into the newly unfrozen layers.
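And a hedged usage sketch for TransferLearningResNet: train only the new classifier head first, then unfreeze the last residual block for fine-tuning. The class and its set_trainable_layers method come from the code above; the class count, data, and learning rates are illustrative, and freeze_backbone is reset manually so gradients can reach the unfrozen layers (see the note in set_trainable_layers).

# Hypothetical usage sketch (not part of the original gist).
import torch
import torch.nn as nn
import torchvision.models as models

model = TransferLearningResNet(
    num_classes=10,
    model_fn=models.resnet18,
    model_weights=models.ResNet18_Weights.IMAGENET1K_V1,
    freeze_backbone=True,
    inference_mode_backbone=True,
)

# Phase 1: only the new classifier head is trainable.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)

images = torch.randn(8, 3, 224, 224)   # stands in for a real DataLoader batch
labels = torch.randint(0, 10, (8,))

model.train()
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()

# Phase 2: unfreeze the last residual block and fine-tune it with a smaller LR.
model.set_trainable_layers(["layer4"])
model.freeze_backbone = False  # let forward() run the backbone with gradients again
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)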