自分のメモのために作りました。
W_OUT = Output width
W_IN  = Input width
F_W   = Filter width (kernel size)
P     = Padding
S     = Stride
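W_OUT = floor((W_IN - F_W + P * 2) / S) + 1
(The division rounds down. This applies to both Conv2d and MaxPool2d with dilation 1 and the default ceil_mode=False, e.g. max-pooling a width of 109 with F_W = 2, S = 2 gives floor(107 / 2) + 1 = 54.)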
import torch.nn as nn
import torch.nn.functional as F
class CNNModel(nn.Module):
    """Small two-convolution CNN classifier.

    Layout:
        conv(3 -> 16, k=3) -> ReLU -> maxpool(2, stride 2)
        -> conv(16 -> 32, k=3) -> ReLU -> maxpool(2, stride 2)
        -> dropout(0.25) -> flatten -> linear(128) -> ReLU -> linear(num_classes)

    Every spatial step follows W_OUT = (W_IN - F_W + 2 * P) // S + 1
    (floor division). The convolutions use P=0, S=1 and the pools use
    F_W=2, S=2.

    Args:
        num_classes: Number of output logits. Defaults to 2.
        input_size: Height and width of the square input images the model
            is built for. It determines the in-features of ``fc1``.
            Defaults to 224.

    Raises:
        ValueError: If ``input_size`` is too small to leave at least a
            1x1 feature map after the second pooling stage.
    """

    def __init__(self, num_classes: int = 2, input_size: int = 224):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.dropout = nn.Dropout(0.25)

        # Trace the spatial width through the feature extractor instead of
        # hard-coding it. For input_size=224: 222 -> 111 -> 109 -> 54.
        width = self._out_width(input_size, 3)         # conv1
        width = self._out_width(width, 2, stride=2)    # maxpool
        width = self._out_width(width, 3)              # conv2
        width = self._out_width(width, 2, stride=2)    # maxpool
        if width < 1:
            raise ValueError(
                f"input_size={input_size} is too small: no spatial output "
                "remains after the second pooling stage"
            )

        self.fc1 = nn.Linear(width * width * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _out_width(w_in: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
        """Return a conv/pool layer's output width (dilation 1, floor mode)."""
        return (w_in - kernel + 2 * padding) // stride + 1

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        ``x`` must be batched (N, 3, H, W). The flattened feature size has to
        match ``fc1``, i.e. H and W as configured by ``input_size``.
        """
        # Width comments below are for the default input_size=224.
        out = self.conv1(x)
        # conv1:   (224 - 3 + 0 * 2) // 1 + 1 = 222
        out = self.relu(out)
        out = self.maxpool(out)
        # maxpool: (222 - 2 + 0 * 2) // 2 + 1 = 111
        out = self.conv2(out)
        # conv2:   (111 - 3 + 0 * 2) // 1 + 1 = 109
        out = self.relu(out)
        out = self.maxpool(out)
        # maxpool: (109 - 2 + 0 * 2) // 2 + 1 = 54  (floor drops the odd column)
        out = self.dropout(out)
        # (N, 32, 54, 54) -> (N, 32 * 54 * 54). Unlike view(), flatten()
        # also works on non-contiguous tensors.
        out = out.flatten(1)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out