Created March 6, 2022 03:53
Save andreaschandra/84a46bdb5374252272b7e8fb2cc982be to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BasicBlock(nn.Module):
    """Residual block built from two ``conv3x3`` layers plus a shortcut.

    Main branch: ``conv3x3 -> norm -> ReLU -> conv3x3 -> norm``. The block
    input (optionally transformed by ``downsample``) is added to the
    main-branch output, and a final ReLU is applied to the sum.

    Args:
        inplanes: Number of input channels passed to the first ``conv3x3``.
        planes: Output channels of both ``conv3x3`` layers; also the channel
            count given to both normalization layers.
        stride: Stride passed to the first ``conv3x3``. When it is not 1, the
            main branch changes resolution, so ``downsample`` must transform
            the shortcut to match (see the comment in ``__init__``).
        downsample: Optional module applied to the block input to form the
            shortcut. Its output must be addable in place to the main-branch
            output, otherwise ``forward`` fails at the residual addition.
        groups: Only ``1`` is supported.
        base_width: Only ``64`` is supported.
        dilation: Values greater than 1 are rejected. The value is not
            otherwise used by this block.
        norm_layer: Callable invoked with a channel count to build each
            normalization layer. Defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """

    # The main branch ends with `planes` channels, so the output/`planes`
    # ratio is 1. NOTE(review): presumably read by the enclosing ResNet
    # builder to size later layers — confirm against the caller.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        # A single ReLU module is reused for both activations in forward();
        # it holds no parameters, so sharing it is safe.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        # Stored for introspection only; forward() does not read it.
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual block to ``x`` and return the activated sum.

        NOTE(review): with the default ``nn.BatchNorm2d``, ``x`` is
        presumably a 4-D (N, C, H, W) tensor with ``inplanes`` channels —
        confirm against the caller.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Transform the shortcut when the main branch changed its shape.
        if self.downsample is not None:
            identity = self.downsample(x)

        # `out` is a fresh tensor produced by bn2, so the in-place add does
        # not overwrite `x` or `identity`.
        out += identity
        out = self.relu(out)

        return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.