Last active: November 5, 2018 23:45
Save Redchards/65f1a6f758a1a5c5efb56f83933c3f6e to your computer and use it in GitHub Desktop.
PyTorch implementation of highway networks.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 5 17:22:52 2018
@author: Vladslinger
"""
import torch | |
from torchvision import datasets, transforms | |
def generate_linear_layers(in_size, out_size, layer_count):
    """Return ``layer_count`` independent ``torch.nn.Linear(in_size, out_size)`` layers.

    Each layer gets its own freshly initialised parameters. The result is a
    plain Python list; wrap it in ``torch.nn.ModuleList`` to register the
    layers as submodules of a ``torch.nn.Module``.

    Args:
        in_size: number of input features of every layer.
        out_size: number of output features of every layer (previously
            ignored, so every layer was ``in_size -> in_size``).
        layer_count: how many layers to create; 0 or less yields ``[]``.
    """
    return [torch.nn.Linear(in_size, out_size) for _ in range(layer_count)]
class HighwayNetwork(torch.nn.Module):
    """Stack of highway layers followed by a linear read-out layer.

    Each highway layer computes::

        gate = nonlinear_function(carry_gate(out))
        out  = gate * activation(linear_term(out)) + (1 - gate) * out

    Where ``gate`` is near 1 the transformed value is used. Where it is
    near 0 the input is carried through unchanged. Every highway layer keeps
    the feature width at ``in_size``. ``final_layer`` then maps ``in_size``
    features to ``out_size`` outputs. ``forward`` returns those raw outputs
    and applies no softmax.

    Args:
        in_size: number of input features (width of every highway layer).
        out_size: number of outputs of the final linear layer.
        layer_count: number of highway layers.
        nonlinear_function: gate nonlinearity. Defaults to a new
            ``torch.nn.Sigmoid()`` per instance.
        activation: nonlinearity on the transform branch. Defaults to a new
            ``torch.nn.ReLU()`` per instance.
        bias: constant used to initialise every gate layer's bias.
    """

    def __init__(self, in_size, out_size, layer_count, nonlinear_function=None,
                 activation=None, bias=-1.0):
        super(HighwayNetwork, self).__init__()
        # Default module instances used to be created once at definition time
        # and shared (and registered as submodules) by every model; build
        # fresh ones per instance instead.
        if nonlinear_function is None:
            nonlinear_function = torch.nn.Sigmoid()
        if activation is None:
            activation = torch.nn.ReLU()

        self.carry_gate_list = torch.nn.ModuleList(generate_linear_layers(in_size, in_size, layer_count))
        self.linear_term_list = torch.nn.ModuleList(generate_linear_layers(in_size, in_size, layer_count))
        self.nonlinear_function = nonlinear_function
        self.out_size = out_size
        self.activation = activation
        self.final_layer = torch.nn.Linear(in_size, out_size)
        # Not used by forward(). Explicit dim=1 (the class dimension of a
        # (batch, out_size) output) instead of the deprecated implicit dim.
        self.output_function = torch.nn.Softmax(dim=1)

        # With a sigmoid gate, a negative bias pushes the initial gate toward 0,
        # favouring the carry term (1 - gate) * out early in training.
        for carry_gate in self.carry_gate_list:
            torch.nn.init.constant_(carry_gate.bias, bias)

    def forward(self, x):
        """Apply every highway layer in order, then the final linear layer.

        Args:
            x: input tensor whose last dimension is ``in_size``.

        Returns:
            The output of ``final_layer`` (last dimension ``out_size``),
            without any softmax.
        """
        out = x
        for carry_gate, linear_term in zip(self.carry_gate_list, self.linear_term_list):
            gate = self.nonlinear_function(carry_gate(out))
            out = gate * self.activation(linear_term(out)) + (1.0 - gate) * out
        return self.final_layer(out)
def _mnist_loader(train, batch_size):
    """Return a DataLoader over the MNIST train (``train=True``) or test split.

    Images are converted to tensors and normalised with mean 0.1307 and std
    0.3081. The dataset is downloaded into ``../data`` if it is missing. Only
    the training split is shuffled; evaluation order does not matter.
    """
    dataset = datasets.MNIST('../data', train=train, download=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.1307,), (0.3081,)),
                             ]))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train)


def _batch_correct(logits, target):
    """Return how many rows of ``logits`` have their arg-max equal to ``target``."""
    return (logits.argmax(dim=1) == target).sum().item()


if __name__ == '__main__':
    batch_size = 64
    image_size = 28 * 28
    nb_digits: int = 10
    learning_rate = 0.0001

    train_loader = _mnist_loader(True, batch_size)
    test_loader = _mnist_loader(False, batch_size)
    # ``.data`` replaces the deprecated ``train_data`` attribute.
    print(train_loader.dataset.data.size())

    model = HighwayNetwork(image_size, nb_digits, 25)
    # CrossEntropyLoss takes raw scores, which is what HighwayNetwork returns.
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # One pass over the training split.
    model.train()
    for i, (data, target) in enumerate(train_loader):
        data = data.reshape(-1, image_size)
        # Clear the previous step's gradients; otherwise backward() keeps
        # accumulating into them and every update uses the running sum.
        optimizer.zero_grad()
        forward_pass = model(data)
        err = loss(forward_pass, target)
        err.backward()
        optimizer.step()
        acc = _batch_correct(forward_pass, target) / target.shape[0]
        print("Batch {} : Loss {:.4f}".format(i, err.item()))
        print("Accuracy {}%".format(acc * 100))

    # Evaluate on the held-out test split, without building autograd graphs.
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.reshape(-1, image_size)
            forward_pass = model(data)
            # loss() averages over the batch; weight by batch size so the
            # final figure is a mean over the whole split.
            total_loss += loss(forward_pass, target).item() * target.shape[0]
            total_correct += _batch_correct(forward_pass, target)
            total_seen += target.shape[0]
    print("Test loss {:.4f}".format(total_loss / total_seen))
    print("Test accuracy {}%".format(100.0 * total_correct / total_seen))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.