import argparse
import gzip
import os
import pickle
from copy import deepcopy
from urllib import request

import matplotlib
matplotlib.use("agg")  # force a non-interactive backend (matplotlib is otherwise unused here)

import numpy as np
import torch
from torch import cuda, nn
from torch.backends import cudnn

from apex import amp  # NVIDIA apex, provides automatic mixed precision (AMP)
def load(mnist_file):
    """Load the pickled MNIST arrays, downloading and caching them first if needed."""
    init()
    with open(mnist_file, 'rb') as f:
        mnist = pickle.load(f)
    # reshape the flat images to NCHW for the conv net
    data_tr = mnist["training_images"].reshape(60000, 1, 28, 28)
    data_te = mnist["test_images"].reshape(10000, 1, 28, 28)
    return data_tr, mnist["training_labels"], data_te, mnist["test_labels"]
filename = [
    ["training_images", "train-images-idx3-ubyte.gz"],
    ["test_images", "t10k-images-idx3-ubyte.gz"],
    ["training_labels", "train-labels-idx1-ubyte.gz"],
    ["test_labels", "t10k-labels-idx1-ubyte.gz"]
]
def download_mnist():
    base_url = "http://yann.lecun.com/exdb/mnist/"
    for name in filename:
        print("Downloading " + name[1] + "...")
        request.urlretrieve(base_url + name[1], name[1])
    print("Download complete.")
def save_mnist():
    mnist = {}
    # image files: skip the 16-byte idx header, then flatten to (N, 784)
    for name in filename[:2]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28 * 28)
    # label files: skip the 8-byte idx header
    for name in filename[-2:]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist, f)
    print("Save complete.")
def init():
    if not os.path.isfile("mnist.pkl"):
        download_mnist()
        save_mnist()
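# Usage sketch (illustrative, not executed): load() transparently downloads
# and caches MNIST on first use, then returns plain numpy arrays:
#
#   data_tr, target_tr, data_te, target_te = load("mnist.pkl")
#   # data_tr: uint8, shape (60000, 1, 28, 28); target_tr: uint8, shape (60000,)
#   # data_te: uint8, shape (10000, 1, 28, 28); target_te: uint8, shape (10000,)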
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    """Polynomial learning rate decay: starts at initial_lr, reaches 0 at max_epochs."""
    return initial_lr * (1 - epoch / max_epochs) ** exponent
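# Illustrative values of the schedule for initial_lr=1e-3 over 30 epochs:
#
#   poly_lr(0, 30, 1e-3)   # 1.000e-3
#   poly_lr(15, 30, 1e-3)  # ~5.36e-4 (i.e. 1e-3 * 0.5 ** 0.9)
#   poly_lr(29, 30, 1e-3)  # ~4.7e-5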
class GlobalAveragePool(nn.Module):
    """Averages over all spatial axes, reducing (N, C, *spatial) to (N, C)."""
    def forward(self, x):
        axes = range(2, len(x.shape))
        # reduce the trailing axes one by one, from last to first
        for a in axes[::-1]:
            x = x.mean(a, keepdim=False)
        return x
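# Shape sketch (illustrative): a (4, 128, 7, 7) activation map becomes (4, 128):
#
#   pool = GlobalAveragePool()
#   pool(torch.randn(4, 128, 7, 7)).shape  # torch.Size([4, 128])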
def get_default_network_config():
    """
    Returns a dictionary holding the conv, nonlinearity, norm and dropout ops
    together with the default kwargs I like to use.
    :return: dict of ops and their kwargs
    """
    props = {}
    props['conv_op'] = nn.Conv2d
    props['conv_op_kwargs'] = {'stride': 1, 'dilation': 1, 'bias': True}  # kernel size will be set by network!
    props['nonlin'] = nn.LeakyReLU
    props['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
    props['norm_op'] = nn.BatchNorm2d
    props['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
    props['dropout_op'] = nn.Dropout2d
    props['dropout_op_kwargs'] = {'p': 0.0, 'inplace': True}
    return props
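# Customization sketch (illustrative): any entry can be overridden before the
# dict is handed to the network, e.g. to enable spatial dropout:
#
#   props = get_default_network_config()
#   props['dropout_op_kwargs']['p'] = 0.5
#   net = SimpleNetwork(props)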
class ConvDropoutNormReLU(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props):
        """
        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels:
        :param output_channels:
        :param kernel_size:
        :param network_props:
        """
        super(ConvDropoutNormReLU, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.

        self.conv = network_props['conv_op'](input_channels, output_channels, kernel_size,
                                             padding=[(i - 1) // 2 for i in kernel_size],
                                             **network_props['conv_op_kwargs'])
        ops = [self.conv]

        # maybe dropout
        if network_props['dropout_op'] is not None:
            self.do = network_props['dropout_op'](**network_props['dropout_op_kwargs'])
            ops.append(self.do)

        # maybe norm
        if network_props['norm_op'] is not None:
            self.norm = network_props['norm_op'](output_channels, **network_props['norm_op_kwargs'])
            ops.append(self.norm)

        self.nonlin = network_props['nonlin'](**network_props['nonlin_kwargs'])
        ops.append(self.nonlin)

        # nn.Sequential only accepts nn.Module instances, so `lambda x: x`
        # placeholders would crash here; optional ops are skipped instead
        self.all = nn.Sequential(*ops)

    def forward(self, x):
        return self.all(x)
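# Block sketch (illustrative): with 'same' padding the spatial size is kept,
# only the channel count changes:
#
#   block = ConvDropoutNormReLU(1, 16, (3, 3), get_default_network_config())
#   block(torch.randn(2, 1, 28, 28)).shape  # torch.Size([2, 16, 28, 28])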
class StackedConvLayers(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
        """
        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels:
        :param output_channels:
        :param kernel_size:
        :param network_props:
        :param num_convs: number of ConvDropoutNormReLU blocks to stack
        :param first_stride: stride of the first conv (for downsampling); the remaining convs use the default stride
        """
        super(StackedConvLayers, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.
        network_props_first = deepcopy(network_props)

        if first_stride is not None:
            network_props_first['conv_op_kwargs']['stride'] = first_stride

        self.convs = nn.Sequential(
            ConvDropoutNormReLU(input_channels, output_channels, kernel_size, network_props_first),
            *[ConvDropoutNormReLU(output_channels, output_channels, kernel_size, network_props)
              for _ in range(num_convs - 1)]
        )

    def forward(self, x):
        return self.convs(x)
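# Downsampling sketch (illustrative): first_stride=2 halves the spatial size
# in the first block of the stack:
#
#   stage = StackedConvLayers(16, 32, (3, 3), get_default_network_config(), 2, first_stride=2)
#   stage(torch.randn(2, 16, 28, 28)).shape  # torch.Size([2, 32, 14, 14])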
class SimpleNetwork(nn.Module):
    def __init__(self, props=None):
        super(SimpleNetwork, self).__init__()
        if props is None:
            props = get_default_network_config()
        # four stages; stages 2-4 downsample via first_stride=2: 28 -> 14 -> 7 -> 4
        self.stage1 = StackedConvLayers(1, 16, (3, 3), props, 2, 1)
        self.stage2 = StackedConvLayers(16, 32, (3, 3), props, 2, 2)
        self.stage3 = StackedConvLayers(32, 64, (3, 3), props, 3, 2)
        self.stage4 = StackedConvLayers(64, 128, (3, 3), props, 3, 2)
        self.pool = GlobalAveragePool()
        self.fc = nn.Linear(128, 10, False)  # bias=False; pooled 128-d features map to 10 logits

    def forward(self, x):
        return self.fc(self.pool(self.stage4(self.stage3(self.stage2(self.stage1(x))))))
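# Forward-pass sketch (illustrative): a batch of MNIST images maps to 10 logits:
#
#   net = SimpleNetwork()
#   net(torch.randn(8, 1, 28, 28)).shape  # torch.Size([8, 10])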
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=False, default=None)
    parser.add_argument("--test_only", action="store_true")
    parser.add_argument("-s", help="output filename for trained model")
    parser.add_argument("-test_fnames", required=False, nargs='+',
                        help="checkpoint files to evaluate (with --test_only)")
    args = parser.parse_args()

    seed = args.seed
    test_only = args.test_only

    # seeding: numpy drives the torch RNGs; torch.manual_seed also covers the
    # CPU RNG, which is the one used for weight init before .cuda()
    np.random.seed(seed)
    torch.manual_seed(np.random.randint(10000))
    cuda.manual_seed(np.random.randint(10000))
    cuda.manual_seed_all(np.random.randint(10000))
    cudnn.deterministic = True
    cudnn.benchmark = False

    data_tr, target_tr, data_te, target_te = load("mnist.pkl")

    # MNIST is small enough to keep entirely in GPU memory
    data_tr = torch.from_numpy(data_tr).float().cuda()
    target_tr = torch.from_numpy(target_tr).long().cuda()
    data_te = torch.from_numpy(data_te).float().cuda()
    target_te = torch.from_numpy(target_te).long().cuda()
    network = SimpleNetwork().cuda()
    batch_size = 512

    if not test_only:
        optimizer = torch.optim.Adam(network.parameters(), 1e-3, amsgrad=True, weight_decay=1e-5)
        # opt_level O1 runs whitelisted ops in fp16 while keeping fp32 master weights
        network, optimizer = amp.initialize(network, optimizer, opt_level="O1")

        epochs = 30
        loss = torch.nn.CrossEntropyLoss()

        network.train()
        for epoch in range(epochs):
            print(epoch)
            optimizer.param_groups[0]['lr'] = poly_lr(epoch, epochs, 1e-3, 0.9)
            for _ in range(60000 // batch_size):
                optimizer.zero_grad()
                # sample a random batch (with replacement) directly from GPU memory
                idxs = np.random.choice(60000, batch_size)
                data = data_tr[idxs]
                target = target_tr[idxs]
                out = network(data)
                batch_loss = loss(out, target)
                # amp scales the loss to avoid fp16 gradient underflow and
                # unscales the gradients again before optimizer.step()
                with amp.scale_loss(batch_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

        torch.save(network.state_dict(), args.s)

        with torch.no_grad():
            network.eval()
            out = network(data_te)
            _, amax = out.max(dim=1)
            acc = (amax == target_te).float().mean()
            print("accuracy on test:", acc.item())
    else:
        if args.test_fnames is None:
            raise ValueError("--test_only requires -test_fnames")
        # amp.initialize must only be called once, so call it before the checkpoint loop
        network = amp.initialize(network, opt_level="O1")
        for f in args.test_fnames:
            network.load_state_dict(torch.load(f, map_location=torch.device('cuda', torch.cuda.current_device())))
            with torch.no_grad():
                network.eval()
                out = network(data_te)
                _, amax = out.max(dim=1)
                acc = (amax == target_te).float().mean()
                print("file", f, "accuracy on test:", acc.item())