Created February 18, 2019 23:10
-
-
Save yuikns/e331a2f074ce6c47a49a4470226ba75c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
# Dropout probability used by the dropout layer inside ``drnet.seq``.
p_dropout = 0.1


class drnet(nn.Module):
    """Small CNN classifier: a conv/pool feature extractor plus a linear head.

    ``self.seq`` is a fixed stack of convolutions, batch norms, activations,
    dropout and max-pooling. Its output is flattened by ``conv`` and fed to
    ``self.fc`` (``nn.Linear`` followed by ``nn.Softmax(dim=1)``), so
    ``forward`` returns per-class probabilities, not logits.

    The module also keeps named in-memory weight snapshots (``save`` / ``load``).

    Args:
        in_channels: Channels of the input images (default 1).
        input_size: ``(height, width)`` of the input images. It is used once,
            at construction, to infer the flattened feature size
            (default ``(28, 28)``).
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, in_channels=1, input_size=(28, 28), num_classes=10):
        super().__init__()
        # NOTE: the layer order defines the state_dict key layout
        # ("seq.0.weight", "seq.1.running_mean", ...); keep it stable so
        # previously saved weights still load.
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=8,
                      kernel_size=2, stride=1, padding=0),
            nn.BatchNorm2d(8),
            nn.RReLU(inplace=True),
            nn.Conv2d(in_channels=8, out_channels=16,
                      kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=4),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p_dropout),
            nn.Conv2d(in_channels=16, out_channels=16,
                      kernel_size=1, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        in_features = self._infer_num_features(in_channels, input_size)
        logging.info("initialized cnn.conv, num feature dimension: %s",
                     in_features)
        self.fc = nn.Sequential(
            nn.Linear(in_features, num_classes),
            # NOTE(review): this head emits probabilities. If training uses
            # nn.CrossEntropyLoss (which applies log-softmax itself), softmax
            # ends up applied twice -- confirm against the training code.
            nn.Softmax(dim=1),
        )
        # Named in-memory weight snapshots, managed by save()/load().
        self.models = {}

    def _infer_num_features(self, in_channels, input_size):
        """Return the flattened size of ``self.seq``'s output for one input.

        Runs a single all-zero dummy forward pass. It is done in eval mode and
        without autograd so that it neither updates BatchNorm running
        statistics (which training-mode BatchNorm does on every forward) nor
        builds a graph. The previous training mode is restored afterwards.
        """
        height, width = input_size
        was_training = self.seq.training
        self.seq.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, in_channels, height, width)
                return int(self.conv(probe).size(1))
        finally:
            self.seq.train(was_training)

    def conv(self, x):
        """Run the feature extractor and flatten to ``(batch, features)``."""
        x = self.seq(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for batch ``x``."""
        return self.fc(self.conv(x))

    def save(self, key):
        """Store a snapshot of the current weights under ``key`` (in memory).

        ``state_dict()`` only holds references that share storage with the
        live parameters and buffers. It is therefore deep-copied; otherwise
        later in-place updates (e.g. optimizer steps) would also change the
        "saved" snapshot, and ``load`` would restore nothing.
        """
        self.models[key] = copy.deepcopy(self.state_dict())

    def load(self, key):
        """Restore the weights stored under ``key`` by ``save``.

        If ``key`` is unknown, an error is logged and the weights are left
        unchanged.
        """
        if key not in self.models:
            logging.error("key %s not found", key)
            return
        self.load_state_dict(self.models[key], strict=True)
# `device` is not defined anywhere in this file, so the original line raised
# NameError. Keep any `device` the enclosing script/notebook already set;
# otherwise use the first CUDA GPU if available, else the CPU.
if "device" not in globals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = drnet().to(device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.