Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Created January 6, 2021 06:21
Show Gist options
  • Select an option

  • Save sadimanna/33824f6965fc46396ecfd0cbca5a6902 to your computer and use it in GitHub Desktop.

Select an option

Save sadimanna/33824f6965fc46396ecfd0cbca5a6902 to your computer and use it in GitHub Desktop.
class IntResNet(ResNet):
    """ResNet whose forward pass stops at a named child module.

    The child modules registered by ``ResNet.__init__`` are executed in
    registration order, up to and including ``output_layer``, and the output
    of that module is returned. Children after ``output_layer`` are left
    registered (nothing removes them), so their parameters remain part of the
    model's state dict; they are simply never executed.

    Args:
        output_layer: Name of the last child module to run, e.g. ``"layer3"``
            or ``"avgpool"``. Must be a key of ``self._modules``.
        *args: Positional arguments forwarded unchanged to ``ResNet.__init__``.
        **kwargs: Keyword arguments forwarded unchanged to ``ResNet.__init__``.

    Raises:
        ValueError: If ``output_layer`` is not the name of a child module.
    """

    def __init__(self, output_layer, *args, **kwargs):
        # Initialise the parent first so nn.Module's internal registries exist
        # before any attribute is assigned on self.
        super().__init__(*args, **kwargs)
        self.output_layer = output_layer

        names = list(self._modules.keys())
        if output_layer not in names:
            # Previously an unknown name silently kept *every* child (including
            # the classifier), which only failed later inside forward().
            raise ValueError(
                f"output_layer {output_layer!r} is not a child module; "
                f"expected one of {names}"
            )
        # Child names in registration order, up to and including output_layer.
        self._layers = names[: names.index(output_layer) + 1]
        self.layers = OrderedDict(
            (name, getattr(self, name)) for name in self._layers
        )

    def _forward_impl(self, x):
        """Run the retained child modules in order and return the last output."""
        for name in self._layers:
            if name == "fc":
                # torchvision's ResNet flattens the pooled features
                # (torch.flatten(x, 1)) before its fully-connected classifier;
                # mirror that so output_layer="fc" does not hit a shape error.
                x = x.flatten(1)
            x = self.layers[name](x)
        return x

    def forward(self, x):
        """Return the activation produced by ``output_layer`` for input ``x``."""
        return self._forward_impl(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment