Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Created June 30, 2021 11:07
Show Gist options
  • Save sadimanna/fb10640660d76da9216701a0a70938ce to your computer and use it in GitHub Desktop.
class DSModel(nn.Module):
    """Linear classifier trained on top of a frozen pretrained model.

    Every parameter of ``premodel`` is frozen (``requires_grad = False``), so
    only the newly created linear layer ``lastlayer`` is trainable. The forward
    pass runs ``premodel.pretrained`` on the input and feeds its output to
    ``lastlayer``. Any other submodules of ``premodel`` (e.g. a ``projector``)
    are frozen but not used.

    Args:
        premodel: Module exposing a ``pretrained`` submodule used as the
            feature extractor.
        num_classes: Number of outputs of the linear head.
        in_features: Size of the last dimension of the tensor returned by
            ``premodel.pretrained``. Defaults to 2048, the value the original
            code hard-coded (presumably a ResNet-50 feature width -- confirm
            for other backbones).
    """

    def __init__(self, premodel, num_classes, in_features=2048):
        super().__init__()
        self.premodel = premodel
        self.num_classes = num_classes
        # Module.parameters() recurses into every registered submodule
        # (including a ``projector``, if present), so one pass freezes the
        # whole pretrained model; no separate per-submodule loop is needed.
        for p in self.premodel.parameters():
            p.requires_grad = False
        self.lastlayer = nn.Linear(in_features, self.num_classes)

    def forward(self, x):
        """Return class logits for ``x``.

        ``x`` is passed to ``premodel.pretrained``; its output must have a
        last dimension of size ``in_features`` for ``lastlayer`` to accept it.
        """
        out = self.premodel.pretrained(x)
        return self.lastlayer(out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment