Skip to content

Instantly share code, notes, and snippets.

@yuikns
Created February 18, 2019 23:10
Show Gist options
  • Save yuikns/e331a2f074ce6c47a49a4470226ba75c to your computer and use it in GitHub Desktop.
Save yuikns/e331a2f074ce6c47a49a4470226ba75c to your computer and use it in GitHub Desktop.
# Imports for the CNN model definition below (PyTorch + stdlib logging).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
# Dropout probability applied inside the conv stack of drnet.
p_dropout = 0.1
class drnet(nn.Module):
    """Compact CNN classifier for single-channel 28x28 inputs, 10 classes.

    Architecture: 2x2 conv -> BN -> RReLU -> 1x1 conv -> 4x4 max-pool ->
    BN -> ReLU -> dropout -> 1x1 conv -> 2x2 max-pool -> BN -> ReLU,
    then a single linear layer over the flattened features, ending in a
    softmax over classes.

    Also keeps an in-memory dict of parameter snapshots; see save()/load().
    """

    def __init__(self, p_dropout=0.1):
        """Build the conv stack and the linear classifier head.

        Args:
            p_dropout: dropout probability used in the conv stack.
                Defaults to 0.1, the value of the module-level constant,
                so existing no-argument callers are unaffected.
        """
        super(drnet, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=8,
                kernel_size=2,
                stride=1,
                padding=0,
            ),
            nn.BatchNorm2d(8),
            nn.RReLU(inplace=True),
            nn.Conv2d(
                in_channels=8,
                out_channels=16,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.MaxPool2d(kernel_size=4),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p_dropout),
            nn.Conv2d(
                in_channels=16,
                out_channels=16,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )
        # Infer the flattened feature count by tracing a dummy 28x28 input
        # through the conv stack (avoids hand-computing the spatial math).
        in_features = int(self.conv(torch.zeros(1, 1, 28, 28)).size(1))
        out_features = 10
        logging.info("initialized cnn.conv, num feature dimension: {}".format(
            in_features))
        self.fc = nn.Sequential(
            nn.Linear(in_features, out_features),
            # NOTE(review): a Softmax head is only correct if the training
            # loss expects probabilities (e.g. log() + NLLLoss). With
            # nn.CrossEntropyLoss this layer must be removed, since that loss
            # applies log-softmax internally -- confirm against the trainer.
            nn.Softmax(dim=1),
        )
        # In-memory snapshots of state_dicts, keyed by caller-supplied name.
        self.models = {}

    def conv(self, x):
        """Apply the conv stack and flatten to shape (batch, features)."""
        x = self.seq(x)
        x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        """Return per-class probabilities of shape (batch, 10)."""
        x = self.conv(x)
        return self.fc(x)

    def save(self, key):
        """Snapshot the current parameters under `key`.

        BUGFIX: state_dict() returns references to the live tensors, so
        storing it directly meant later training steps silently mutated the
        "snapshot" and load() restored the mutated values. Clone each tensor
        so the snapshot is isolated from subsequent updates.
        """
        self.models[key] = {
            name: tensor.detach().clone()
            for name, tensor in self.state_dict().items()
        }

    def load(self, key):
        """Restore the snapshot saved under `key`; log an error if unknown."""
        if key in self.models:
            self.load_state_dict(self.models[key], strict=True)
        else:
            logging.error("key {} not found".format(key))
net = drnet().to(device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment