Skip to content

Instantly share code, notes, and snippets.

@RicherMans
Created April 17, 2019 11:08
Show Gist options
  • Select an option

  • Save RicherMans/6c38249d582b23f1ee7ed91d7d7894ca to your computer and use it in GitHub Desktop.

Select an option

Save RicherMans/6c38249d582b23f1ee7ed91d7d7894ca to your computer and use it in GitHub Desktop.
Pretrained ResNet used for SED
class PretrainedResNet(torch.nn.Module):
    """3-D ResNet from ``pretorched`` with a new linear head applied per step.

    The pretrained backbone's average pool is replaced by one that pools
    away only the last two axes (axis 2 of the backbone output is kept).
    Its classifier ``last_linear`` is replaced by a fresh linear layer that
    maps every retained step to ``outputdim`` outputs.

    Args:
        outputdim: Number of output units of the new linear head.
        model_name: Name of the ``pretorched`` model factory to call.
        pretrained: Name of the pretrained weights passed to the factory.
        num_classes: Size of the classifier head the checkpoint was trained
            with. It must match ``pretrained``. 339 is presumably the
            Moments-in-Time label count; confirm this before changing weights.
    """

    def __init__(self, outputdim, *, model_name='resnet3d50',
                 pretrained='moments', num_classes=339):
        super().__init__()
        # Deferred import: ``pretorched`` is only needed when a model is
        # actually built, not when the module defining this class is imported.
        import pretorched
        self._outputdim = outputdim
        self.net = pretorched.__dict__[model_name](
            num_classes=num_classes, pretrained=pretrained)
        # Size the new head from the backbone's own classifier instead of a
        # hard-coded width, so backbones other than resnet3d50 also work.
        # Fall back to 2048 (the width originally assumed for resnet3d50)
        # if the existing head does not expose ``in_features``.
        in_features = getattr(self.net.last_linear, 'in_features', 2048)
        self.net.last_linear = torch.nn.Linear(in_features, outputdim)
        # ``None`` keeps axis 2 (presumably time) and collapses the last two
        # axes to 1, so forward() yields one prediction per retained step.
        self.net.avgpool = torch.nn.AdaptiveAvgPool3d((None, 1, 1))

    def forward(self, x):
        """Run the backbone and return per-step outputs.

        Args:
            x: Input tensor for the 3-D backbone. It is assumed to be
                ``(batch, channels, time, height, width)``; TODO confirm
                with the caller.

        Returns:
            Tensor of shape ``(batch, steps, outputdim)``, where ``steps`` is
            the size of axis 2 after the backbone stages.
        """
        net = self.net
        # Same stage order as the backbone's own forward pass, minus the
        # final flatten/classifier.
        for stage in (net.conv1, net.bn1, net.relu, net.maxpool,
                      net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
        # (B, C, T', 1, 1) -> (B, T', C, 1, 1): move steps next to batch.
        x = net.avgpool(x).transpose(1, 2)
        # Merge the feature axis with the trailing singleton axes -> (B, T', C).
        # flatten copies when the layout requires it, so no contiguous() call
        # is needed.
        x = x.flatten(start_dim=2)
        return net.last_linear(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment