Created February 17, 2021 17:39
A little helper for finetuning networks
from torch import nn
import torch
from typing import List


def filter_parameters_for_finetuning(module: nn.Module) -> List[torch.Tensor]:
    """
    Args:
        module: A :py:class:`nn.Module` object where some of the children may have
            a boolean attribute ``finetune``, which if it exists and is ``False``,
            will exclude parameters from this submodule from the result.

    Returns:
        A list of parameters in the module for finetuning.
    """
    params = []
    # We're going to look at each direct child of the current module and recurse into
    # those that don't have ``finetune=False``.
    for child in module.children():
        if hasattr(child, 'finetune') and not child.finetune:
            # We don't recurse into this part of the module subtree since we want to
            # freeze all these parameters
            continue
        # but if the child is going to be finetuned, we need to add all parameters
        # declared in the child
        params.extend(child.parameters(recurse=False))
        # and all of its children which also have ``finetune=True``
        params.extend(filter_parameters_for_finetuning(child))
    return params
def demo():
    class SubSubSubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(5, 7)

    class SubSubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(10, 15)
            self.m = SubSubSubModule()
            self.finetune = False

    class SubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(20, 30)
            self.m = SubSubModule()

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(40, 60)
            self.m = SubModule()

    net = Net()
    print([
        param.shape for param in filter_parameters_for_finetuning(net)
    ])
    # Outputs
    # [torch.Size([60, 40]),  <- Net.lin.weight
    #  torch.Size([60]),      <- Net.lin.bias
    #  torch.Size([30, 20]),  <- Net.m.lin.weight
    #  torch.Size([30])]      <- Net.m.lin.bias
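A typical way to use the result (a hedged sketch, not part of the gist; the model below is a made-up example using the same ``finetune=False`` convention) is to hand the filtered parameters to an optimizer, and optionally disable gradients for everything that was filtered out:

# Hedged usage sketch: a hypothetical model marking its first layer as frozen.
model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
model[0].finetune = False

# Only the filtered parameters are handed to the optimizer, so the frozen
# submodule receives no updates.
optimizer = torch.optim.SGD(filter_parameters_for_finetuning(model), lr=1e-3)

# Optionally disable gradient computation for the excluded parameters as well.
finetune_ids = {id(p) for p in filter_parameters_for_finetuning(model)}
for param in model.parameters():
    param.requires_grad = id(param) in finetune_ids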
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Currently this doesn't support overriding finetuning of parts of a submodule which has finetune=False, e.g. if SubSubSubModule had finetune=True it would still be excluded.
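One way such overrides could be supported (a hedged sketch, not part of the gist; filter_parameters_with_overrides is a hypothetical name, reusing the imports from the file above) is to keep recursing into frozen subtrees and let an explicit finetune attribute override whatever the parent specified:

def filter_parameters_with_overrides(module: nn.Module, default: bool = True) -> List[torch.Tensor]:
    """Like ``filter_parameters_for_finetuning``, but an explicit ``finetune=True``
    on a descendant re-enables finetuning inside an otherwise frozen subtree."""
    params = []
    for child in module.children():
        # An explicit ``finetune`` attribute on the child overrides the inherited default.
        finetune = getattr(child, 'finetune', default)
        if finetune:
            params.extend(child.parameters(recurse=False))
        # Recurse either way so nested overrides are still honoured.
        params.extend(filter_parameters_with_overrides(child, default=finetune))
    return params

With no overrides present this collects the same parameters as the original helper; the difference is only visible when a module inside a frozen subtree opts back in.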