Basic PyTorch net definition
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify just a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)
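
As a quick sanity check (not part of the original gist), note that the `16 * 6 * 6` in `fc1` implies a single-channel 32x32 input: each 3x3 convolution trims two pixels per side and each 2x2 max-pool halves the spatial size, so 32 -> 30 -> 15 -> 13 -> 6. A minimal forward-pass sketch under that assumption:

# Minimal usage sketch (assumes a batch of one 1x32x32 image, as implied by fc1's input size)
dummy_input = torch.randn(1, 1, 32, 32)
out = net(dummy_input)
print(out.shape)  # expected: torch.Size([1, 10])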