Skip to content

Instantly share code, notes, and snippets.

@tuna2134
Last active September 30, 2023 02:23
Show Gist options
  • Save tuna2134/41b5f33672d18b24a377cd950ebb939e to your computer and use it in GitHub Desktop.
Save tuna2134/41b5f33672d18b24a377cd950ebb939e to your computer and use it in GitHub Desktop.

PyTorch CNN(Conv2d解説)

自分のメモのために作りました。

Math

W_OUT = Output width

W_IN = Input width

F_W = Filter width (kernel size)

P = Padding

S = Stride

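$$W_{OUT} = \left\lfloor \frac{W_{IN} - F_W + 2P}{S} \right\rfloor + 1$$

The division is floored (integer division). MaxPool2d uses the same formula with the pooling window as F_W, so a 2×2 pool with stride 2 maps W to ⌊W/2⌋. For example, 222 → 111 and 109 → 54.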

Sample

import torch.nn as nn
import torch.nn.functional as F


class CNNModel(nn.Module):
    """Small CNN classifier: two (conv -> ReLU -> max-pool) stages, then a 2-layer MLP head.

    Each conv uses a 3x3 kernel with stride 1 and no padding. Each pool uses a
    2x2 window with stride 2. The flattened size fed to ``fc1`` is therefore
    computed from ``input_size`` with the floored output-size formula
    ``W_OUT = floor((W_IN - F_W + 2 * P) / S) + 1`` instead of being hard-coded.

    Args:
        in_channels: Number of input image channels (default 3, e.g. RGB).
        num_classes: Number of output logits produced by ``fc2`` (default 2).
        input_size: Spatial size of the input images, either an int for square
            images or an ``(height, width)`` pair. The default 224 yields the
            original ``54 * 54 * 32`` flattened feature size.

    Raises:
        ValueError: If ``input_size`` is too small for the feature map to
            survive both conv/pool stages (any side smaller than 10).
    """

    def __init__(self, in_channels=3, num_classes=2, input_size=224):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.dropout = nn.Dropout(0.25)

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        feat_h = self._feature_size(height)
        feat_w = self._feature_size(width)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small: the feature map "
                f"would be {feat_h}x{feat_w} after two conv/pool stages"
            )
        # 32 = out_channels of conv2.
        self.fc1 = nn.Linear(feat_h * feat_w * 32, 128)
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _feature_size(size):
        """Return one spatial dimension after both conv(3x3) + max-pool(2x2, stride 2) stages."""
        for _ in range(2):
            size = (size - 3 + 0 * 2) // 1 + 1  # 3x3 conv, stride 1, no padding
            size = (size - 2) // 2 + 1  # 2x2 max-pool, stride 2 (floors odd sizes)
        return size

    def forward(self, x):
        """Return unnormalized class logits for a batch of images.

        Args:
            x: Batched image tensor ``(batch, in_channels, H, W)``, where
                ``(H, W)`` must match the ``input_size`` given at construction
                so the flattened features fit ``fc1``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # Widths in the comments below are for the default 224x224 input.
        out = self.conv1(x)
        # Output(width): (224 - 3 + 0 * 2) // 1 + 1 = 222
        out = self.relu(out)
        out = self.maxpool(out)
        # Maxpooling: (222 - 2) // 2 + 1 = 111
        out = self.conv2(out)
        # Output(width): (111 - 3 + 0 * 2) // 1 + 1 = 109
        out = self.relu(out)
        out = self.maxpool(out)
        # Maxpooling: (109 - 2) // 2 + 1 = 54 (floor drops the last odd row/column)
        out = self.dropout(out)
        # Keep the batch dim and merge C*H*W. Unlike view(), flatten() also
        # works on non-contiguous tensors.
        out = out.flatten(1)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment