@sadimanna
Created December 20, 2020 19:28
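import torch.nn as nn
from torchvision import models


class new_model(nn.Module):
    """ResNet-18 truncated after the child module named `output_layer`."""

    def __init__(self, output_layer):
        super().__init__()
        self.output_layer = output_layer
        self.pretrained = models.resnet18(pretrained=True)
        # Collect top-level children in order, up to and including output_layer.
        self.children_list = []
        for n, c in self.pretrained.named_children():
            self.children_list.append(c)
            if n == self.output_layer:
                break
        self.net = nn.Sequential(*self.children_list)
        # Drop the reference to the full model; self.net holds the kept layers.
        self.pretrained = None

    def forward(self, x):
        # Note: nn.Sequential omits the flatten that ResNet applies before `fc`,
        # so output_layer='fc' will not work without extra handling.
        x = self.net(x)
        return x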