import argparse
import gzip
import os
import pickle
from copy import deepcopy
from urllib import request

import matplotlib
matplotlib.use("agg")  # non-interactive backend, safe on headless machines

import numpy as np
import torch
from torch import cuda, nn
from torch.backends import cudnn
from apex import amp
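# NOTE: mixed-precision training below relies on NVIDIA apex
# (https://github.com/NVIDIA/apex); newer PyTorch releases ship a native
# equivalent in torch.cuda.amp, which this script predates.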


def load(mnist_file):
    # Download and pickle MNIST on first use, then load the cached arrays.
    init()
    with open(mnist_file, 'rb') as f:
        mnist = pickle.load(f)
    # Images are stored flattened as uint8 (N, 784); reshape to NCHW.
    data_tr = mnist["training_images"].reshape(60000, 1, 28, 28)
    data_te = mnist["test_images"].reshape(10000, 1, 28, 28)
    return data_tr, mnist["training_labels"], data_te, mnist["test_labels"]


filename = [
    ["training_images", "train-images-idx3-ubyte.gz"],
    ["test_images", "t10k-images-idx3-ubyte.gz"],
    ["training_labels", "train-labels-idx1-ubyte.gz"],
    ["test_labels", "t10k-labels-idx1-ubyte.gz"]
]


def download_mnist():
    base_url = "http://yann.lecun.com/exdb/mnist/"
    for name in filename:
        print("Downloading " + name[1] + "...")
        request.urlretrieve(base_url + name[1], name[1])
    print("Download complete.")


def save_mnist():
    mnist = {}
    # IDX image files start with a 16-byte header (magic, count, rows, cols),
    # label files with an 8-byte header (magic, count), hence the offsets.
    for name in filename[:2]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28 * 28)
    for name in filename[-2:]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist, f)
    print("Save complete.")


def init():
    # Fetch and pickle MNIST once; subsequent runs reuse mnist.pkl.
    if not os.path.isfile("mnist.pkl"):
        download_mnist()
        save_mnist()


def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    """Polynomial learning rate decay: lr = initial_lr * (1 - epoch / max_epochs) ** exponent."""
    return initial_lr * (1 - epoch / max_epochs) ** exponent
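# For example, with initial_lr=1e-3, max_epochs=30 and exponent=0.9 the
# schedule starts at 1e-3, passes ~5.4e-4 at epoch 15, and the last epoch
# actually trained (epoch 29) uses ~4.7e-5.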


class GlobalAveragePool(nn.Module):
    """Averages over all spatial dimensions, reducing (N, C, *spatial) to (N, C)."""
    def forward(self, x):
        # Reduce the trailing (spatial) axes from last to first so that the
        # remaining axis indices stay valid after each reduction.
        axes = range(2, len(x.shape))
        for a in axes[::-1]:
            x = x.mean(a, keepdim=False)
        return x
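# On recent PyTorch the loop above is equivalent to x.mean(dim=[2, 3]) for 4D
# input, or to nn.AdaptiveAvgPool2d(1) followed by flattening.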


def get_default_network_config():
    """
    Returns a dictionary that holds the conv, nonlinearity, norm and dropout
    ops together with the default kwargs used to instantiate them.
    :return: dict mapping op/kwargs names to classes and kwargs dicts
    """
    props = {}
    props['conv_op'] = nn.Conv2d
    props['conv_op_kwargs'] = {'stride': 1, 'dilation': 1, 'bias': True}  # kernel size will be set by the network!
    props['nonlin'] = nn.LeakyReLU
    props['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
    props['norm_op'] = nn.BatchNorm2d
    props['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
    props['dropout_op'] = nn.Dropout2d
    props['dropout_op_kwargs'] = {'p': 0.0, 'inplace': True}
    return props
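# Hypothetical customization sketch: the config is a plain dict, so swapping
# building blocks is just reassignment, e.g. instance norm plus real dropout:
#   props = get_default_network_config()
#   props['norm_op'] = nn.InstanceNorm2d
#   props['dropout_op_kwargs'] = {'p': 0.5, 'inplace': True}
#   net = SimpleNetwork(props)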


class ConvDropoutNormReLU(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props):
        """
        Conv -> dropout -> norm -> nonlinearity, each op taken from network_props.
        If network_props['dropout_op'] is None then no dropout is applied;
        if network_props['norm_op'] is None then no norm is applied.
        :param input_channels: number of input feature maps
        :param output_channels: number of output feature maps
        :param kernel_size: conv kernel size, e.g. (3, 3); padding is derived from it
        :param network_props: config dict as returned by get_default_network_config
        """
        super(ConvDropoutNormReLU, self).__init__()
        network_props = deepcopy(network_props)  # network_props is a mutable dict, so we deepcopy to be safe
        self.conv = network_props['conv_op'](input_channels, output_channels, kernel_size,
                                             padding=[(i - 1) // 2 for i in kernel_size],
                                             **network_props['conv_op_kwargs'])
        # maybe dropout (nn.Identity instead of a lambda, since nn.Sequential
        # only accepts nn.Module instances)
        if network_props['dropout_op'] is not None:
            self.do = network_props['dropout_op'](**network_props['dropout_op_kwargs'])
        else:
            self.do = nn.Identity()
        if network_props['norm_op'] is not None:
            self.norm = network_props['norm_op'](output_channels, **network_props['norm_op_kwargs'])
        else:
            self.norm = nn.Identity()
        self.nonlin = network_props['nonlin'](**network_props['nonlin_kwargs'])
        self.all = nn.Sequential(self.conv, self.do, self.norm, self.nonlin)

    def forward(self, x):
        return self.all(x)
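# Usage sketch (hypothetical values): a single 3x3 block mapping 16 -> 32
# channels keeps the spatial size because padding is derived from the kernel:
#   block = ConvDropoutNormReLU(16, 32, (3, 3), get_default_network_config())
#   block(torch.randn(4, 16, 28, 28)).shape  # -> torch.Size([4, 32, 28, 28])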


class StackedConvLayers(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
        """
        A stack of num_convs ConvDropoutNormReLU blocks. Only the first block
        changes the channel count (and, optionally, the stride).
        If network_props['dropout_op'] is None then no dropout is applied;
        if network_props['norm_op'] is None then no norm is applied.
        :param input_channels: number of input feature maps
        :param output_channels: number of output feature maps (all blocks)
        :param kernel_size: conv kernel size, e.g. (3, 3)
        :param network_props: config dict as returned by get_default_network_config
        :param num_convs: number of stacked blocks
        :param first_stride: stride of the first block (used for downsampling); None keeps the configured stride
        """
        super(StackedConvLayers, self).__init__()
        network_props = deepcopy(network_props)  # network_props is a mutable dict, so we deepcopy to be safe
        network_props_first = deepcopy(network_props)
        if first_stride is not None:
            network_props_first['conv_op_kwargs']['stride'] = first_stride
        self.convs = nn.Sequential(
            ConvDropoutNormReLU(input_channels, output_channels, kernel_size, network_props_first),
            *[ConvDropoutNormReLU(output_channels, output_channels, kernel_size, network_props)
              for _ in range(num_convs - 1)]
        )

    def forward(self, x):
        return self.convs(x)
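# Note: with the default config, downsampling happens through the strided
# first conv of a stack rather than through pooling; the remaining convs in
# the stack keep stride 1.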


class SimpleNetwork(nn.Module):
    def __init__(self, props=None):
        super(SimpleNetwork, self).__init__()
        if props is None:
            props = get_default_network_config()
        self.stage1 = StackedConvLayers(1, 16, (3, 3), props, 2, 1)
        self.stage2 = StackedConvLayers(16, 32, (3, 3), props, 2, 2)
        self.stage3 = StackedConvLayers(32, 64, (3, 3), props, 3, 2)
        self.stage4 = StackedConvLayers(64, 128, (3, 3), props, 3, 2)
        self.pool = GlobalAveragePool()
        self.fc = nn.Linear(128, 10, bias=False)

    def forward(self, x):
        return self.fc(self.pool(self.stage4(self.stage3(self.stage2(self.stage1(x))))))
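# Shape walkthrough for MNIST input (N, 1, 28, 28): stage1 keeps 28x28 at 16
# channels; the strided stages yield 32x14x14, 64x7x7 and 128x4x4; global
# average pooling reduces that to (N, 128) before the 10-way linear layer.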


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=False, default=None)
    parser.add_argument("--test_only", action="store_true", default=False)
    parser.add_argument("-s", help="output filename for the trained model")
    parser.add_argument("-test_fnames", required=False, nargs='+')
    args = parser.parse_args()

    seed = args.seed
    test_only = args.test_only

    # seeding: seed numpy first, then derive the torch seeds from it
    np.random.seed(seed)
    torch.manual_seed(np.random.randint(10000))  # also seeds CPU-side weight init
    cuda.manual_seed(np.random.randint(10000))
    cuda.manual_seed_all(np.random.randint(10000))
    cudnn.deterministic = True
    cudnn.benchmark = False
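    # cudnn.deterministic=True together with benchmark=False trades some speed
    # for reproducible convolution results on the GPU.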

    data_tr, target_tr, data_te, target_te = load("mnist.pkl")
    data_tr = torch.from_numpy(data_tr).float().cuda()
    target_tr = torch.from_numpy(target_tr).long().cuda()
    data_te = torch.from_numpy(data_te).float().cuda()
    target_te = torch.from_numpy(target_te).long().cuda()

    network = SimpleNetwork().cuda()
    batch_size = 512
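    # The whole dataset stays resident on the GPU, so no DataLoader is needed;
    # as float32 the 60000 training images occupy 60000 * 28 * 28 * 4 bytes,
    # roughly 188 MB.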

    if not test_only:
        optimizer = torch.optim.Adam(network.parameters(), 1e-3, amsgrad=True, weight_decay=1e-5)
        # apex amp patches the model and optimizer for mixed-precision training
        network, optimizer = amp.initialize(network, optimizer, opt_level="O1")
        epochs = 30
        loss = torch.nn.CrossEntropyLoss()
        network.train()
        for epoch in range(epochs):
            print(epoch)
            optimizer.param_groups[0]['lr'] = poly_lr(epoch, epochs, 1e-3, 0.9)
            for _ in range(60000 // batch_size):
                optimizer.zero_grad()
                # sample a random batch (with replacement) instead of using a DataLoader
                idxs = np.random.choice(60000, batch_size)
                data = data_tr[idxs]
                target = target_tr[idxs]
                out = network(data)
                l = loss(out, target)
                # scale the loss so fp16 gradients do not underflow
                with amp.scale_loss(l, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
        torch.save(network.state_dict(), args.s)

        with torch.no_grad():
            network.eval()
            out = network(data_te)
            _, amax = out.max(dim=1)
            acc = (amax == target_te).float().mean()
            print("accuracy on test: ", acc.item())
    else:
        # amp.initialize is meant to be called only once, so patch the model
        # before iterating over the checkpoint files
        network = amp.initialize(network, opt_level="O1")
        if not isinstance(args.test_fnames, list):
            args.test_fnames = [args.test_fnames]
        for f in args.test_fnames:
            network.load_state_dict(torch.load(f, map_location=torch.device('cuda', torch.cuda.current_device())))
            with torch.no_grad():
                network.eval()
                out = network(data_te)
                _, amax = out.max(dim=1)
                acc = (amax == target_te).float().mean()
                print("file", f, "accuracy on test: ", acc.item())