Created September 11, 2017 20:50
Gist: ikhlestov/8f8745790bfa330a572ceac3b7f8a08a
pytorch: mixed models definition
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # a group of layers wrapped in nn.Sequential
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(12, 24, kernel_size=3, padding=1, stride=1),
        )
        # a standalone layer; it keeps 24 channels in and out,
        # so it can be applied more than once in forward()
        self.second_extractor = nn.Conv2d(
            24, 24, kernel_size=3, padding=1, stride=1)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.second_extractor(x)
        # note that we may call the same layer twice or more;
        # its weights are shared between the calls
        x = self.second_extractor(x)
        return x
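The mixed definition can be checked end to end with a short sketch. It redefines the model as a hypothetical `MixedModel` in which `second_extractor` maps 24 to 24 channels (an assumption that makes the repeated call shape-safe), feeds it a dummy 1x3x32x32 batch, and counts parameters to show that calling a layer twice does not duplicate its weights.

```python
import torch
from torch import nn


class MixedModel(nn.Module):
    """Hypothetical variant of Model: an nn.Sequential block plus one reusable layer."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=3, padding=1, stride=1),
            nn.Conv2d(12, 24, kernel_size=3, padding=1, stride=1),
        )
        # assumption: 24 -> 24 channels, so the same layer can be applied repeatedly
        self.second_extractor = nn.Conv2d(24, 24, kernel_size=3, padding=1, stride=1)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.second_extractor(x)
        x = self.second_extractor(x)  # same weights used a second time
        return x


model = MixedModel()
x = torch.randn(1, 3, 32, 32)  # dummy batch: 1 image, 3 channels, 32x32 pixels
out = model(x)
print(out.shape)  # torch.Size([1, 24, 32, 32])

# The reused layer registers its parameters once, however often forward() calls it.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 8160
```

With a 3x3 kernel, `padding=1` and `stride=1`, every convolution preserves the 32x32 spatial size, so only the channel count changes. The total of 8160 parameters covers three convolutions (336 + 2616 + 5208), not four, even though four convolutions run in the forward pass.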
If you call second_extractor twice, won't the shapes stop matching? With Conv2d(24, 36), the first call outputs 36 channels, so the second call receives 36 input channels while the layer expects 24.
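The mismatch described in this comment can be reproduced in isolation. This sketch applies a standalone `nn.Conv2d(24, 36, ...)` twice to a hypothetical 1x24x8x8 input: the first call succeeds, and the second raises a `RuntimeError` because the layer's weight expects 24 input channels but receives 36.

```python
import torch
from torch import nn

# Hypothetical standalone check: a 24 -> 36 convolution applied twice.
layer = nn.Conv2d(24, 36, kernel_size=3, padding=1, stride=1)
x = torch.randn(1, 24, 8, 8)

y = layer(x)
print(y.shape)  # torch.Size([1, 36, 8, 8])

failed = False
try:
    layer(y)  # second call: input has 36 channels, weight expects 24
except RuntimeError as err:
    failed = True
    print("shape mismatch:", err)
```

A layer can only be reused like this when its `in_channels` equals its `out_channels` (for example `nn.Conv2d(36, 36, ...)`); otherwise each step needs its own layer.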