@sadimanna
Last active June 30, 2021 10:42
import torch
import torch.nn as nn
import torchvision.models as models


class Identity(nn.Module):
    """Pass-through module, used to disable layers of the pretrained backbone."""
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
class LinearLayer(nn.Module):
    """Linear layer with optional BatchNorm; the linear bias is dropped when
    BN follows, since BN's own learnable shift makes it redundant."""
    def __init__(self,
                 in_features,
                 out_features,
                 use_bias=True,
                 use_bn=False,
                 **kwargs):
        super(LinearLayer, self).__init__(**kwargs)
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.use_bn = use_bn
        self.linear = nn.Linear(self.in_features,
                                self.out_features,
                                bias=self.use_bias and not self.use_bn)
        if self.use_bn:
            self.bn = nn.BatchNorm1d(self.out_features)

    def forward(self, x):
        x = self.linear(x)
        if self.use_bn:
            x = self.bn(x)
        return x
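# Illustrative check (not part of the original gist): with use_bn=True the
# linear bias is disabled automatically, e.g.
#   layer = LinearLayer(16, 8, use_bias=True, use_bn=True)
#   layer.linear.bias is None  # -> True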
class ProjectionHead(nn.Module):
    """SimCLR-style projection head: either a single linear layer or the
    default two-layer MLP with a ReLU in between."""
    def __init__(self,
                 in_features,
                 hidden_features,
                 out_features,
                 head_type='nonlinear',
                 **kwargs):
        super(ProjectionHead, self).__init__(**kwargs)
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden_features
        self.head_type = head_type
        if self.head_type == 'linear':
            self.layers = LinearLayer(self.in_features, self.out_features, False, True)
        elif self.head_type == 'nonlinear':
            self.layers = nn.Sequential(
                LinearLayer(self.in_features, self.hidden_features, True, True),
                nn.ReLU(),
                LinearLayer(self.hidden_features, self.out_features, False, True))

    def forward(self, x):
        x = self.layers(x)
        return x
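# A quick shape sketch (illustrative, assuming the defaults used below): the
# 'nonlinear' head maps 2048-d backbone features to 128-d projections, e.g.
#   head = ProjectionHead(2048, 2048, 128)
#   z = head(torch.randn(4, 2048))  # -> torch.Size([4, 128])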
class PreModel(nn.Module):
    """ResNet-50 backbone adapted for small images, followed by the
    projection head used for contrastive pretraining."""
    def __init__(self, base_model):
        super().__init__()
        # Name of the backbone (kept for bookkeeping; resnet50 is hard-coded below).
        self.base_model = base_model

        # PRETRAINED MODEL
        self.pretrained = models.resnet50(pretrained=True)
        # Replace the 7x7/stride-2 stem and drop the max-pool so small inputs
        # (e.g. CIFAR) are not downsampled too aggressively.
        self.pretrained.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
        self.pretrained.maxpool = Identity()
        # Drop the classification head; the backbone then outputs 2048-d features.
        self.pretrained.fc = Identity()

        # Fine-tune the entire backbone.
        for p in self.pretrained.parameters():
            p.requires_grad = True

        self.projector = ProjectionHead(2048, 2048, 128)

    def forward(self, x):
        # torchvision's ResNet flattens after average pooling, so the backbone
        # output is already (N, 2048) and can be fed to the projector directly.
        out = self.pretrained(x)
        xp = self.projector(out)
        return xp
model = PreModel('resnet50').to('cuda:0')
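# A minimal smoke test (illustrative, not part of the original gist), assuming
# a CUDA device is available; swap 'cuda:0' for 'cpu' otherwise. The batch size
# must be at least 2 because the projection head uses BatchNorm1d in train mode.
x = torch.randn(8, 3, 32, 32, device='cuda:0')
z = model(x)
print(z.shape)  # expected: torch.Size([8, 128])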