Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active December 28, 2021 21:23
Show Gist options
  • Select an option

  • Save sadimanna/1932980184f74f631e3cbce94ba93a5d to your computer and use it in GitHub Desktop.

Select an option

Save sadimanna/1932980184f74f631e3cbce94ba93a5d to your computer and use it in GitHub Desktop.
class new_model(nn.Module):
    """ResNet-18 feature extractor truncated after a named top-level layer.

    The pretrained torchvision ResNet-18 is loaded and every top-level child
    that comes after ``output_layer`` is discarded. The remaining children are
    chained, under their original names, in ``self.net``.

    Args:
        output_layer: Name of the last top-level child of the ResNet-18 to
            keep (the named layer itself is included). ``None`` keeps the
            whole network.

    Raises:
        ValueError: If ``output_layer`` is not ``None`` and is not the name of
            a top-level child of the ResNet-18.

    Attributes set by ``__init__``:
        output_layer: the value passed in.
        layers: names of all top-level children of the ResNet-18, in order.
        layer_count: number of children that precede ``output_layer``
            (``len(layers)`` when ``output_layer`` is ``None``).
        net: ``nn.Sequential`` holding the retained children.
        pretrained: always ``None`` once construction has finished.
    """

    def __init__(self, output_layer=None):
        super().__init__()
        # Keep the full backbone in a local variable only, so that layers
        # which get dropped are never registered as children of this module.
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=True` -- confirm against the installed
        # version before changing.
        backbone = models.resnet18(pretrained=True)
        children = list(backbone.named_children())

        self.output_layer = output_layer
        self.layers = [name for name, _ in children]

        if output_layer is None:
            self.layer_count = len(self.layers)
        elif output_layer in self.layers:
            self.layer_count = self.layers.index(output_layer)
        else:
            raise ValueError(
                "unknown output_layer {!r}; expected one of {}".format(
                    output_layer, self.layers
                )
            )

        # Build the Sequential through the public add_module API rather than
        # handing over the backbone's private `_modules` mapping, so this does
        # not depend on the concrete type of that mapping.
        self.net = nn.Sequential()
        for name, module in children[: self.layer_count + 1]:
            if name == "fc":
                # ResNet's own forward flattens the pooled features before
                # the fully connected layer; a plain Sequential needs that
                # step spelled out, otherwise `fc` receives a 4-D tensor.
                self.net.add_module("flatten", nn.Flatten(1))
            self.net.add_module(name, module)

        # Preserved for callers that check `model.pretrained is None`.
        self.pretrained = None

    def forward(self, x):
        """Run ``x`` through the retained layers and return the result."""
        return self.net(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment