Use an RNN to fit the addition operator: the model reads two integers digit by digit, least significant digit first, and predicts the digits of their sum.
# coding=utf8
import random
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable

batch_size = 64
n_epochs = 2000
learning_rate = 0.001
n_class = 10  # digits 0-9
num_embeddings = n_class
embedding_dim = 32
input_size = 2 * embedding_dim  # the two operand embeddings are concatenated
hidden_size = 32
output_size = n_class
max_length = 12  # the sum of two operands up to 1e9 has at most 10 digits; pad to 12
dropout = 0.2
evaluate_every = 100
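
# Each training example is a pair of aligned digit sequences: position t holds
# the t-th least significant digit of each operand (see pad_seq below, which
# reverses and zero-pads). With this reversal the carry moves in the same
# direction as the RNN's time steps, which is what makes addition learnable
# for a unidirectional RNN.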

class RNNModel(nn.Module):
    def __init__(self):
        super(RNNModel, self).__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # note: dropout only applies between stacked RNN layers, so it has no
        # effect on this single-layer RNN
        self.rnn = nn.RNN(input_size, hidden_size, dropout=dropout)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, input_var):
        # input_var size (batch_size, max_length, 2)
        embedded_a = self.embedding(input_var[:, :, 0])
        embedded_b = self.embedding(input_var[:, :, 1])
        # embedded size (batch_size, max_length, 2 * embedding_dim)
        embedded = torch.cat((embedded_a, embedded_b), dim=2)
        batch_size = embedded.size(0)
        # the RNN expects (max_length, batch_size, input_size)
        outputs, hidden = self.rnn(embedded.transpose(0, 1), None)
        # outputs size (max_length, batch_size, hidden_size)
        outputs = self.linear(outputs.view(-1, hidden_size)).view(-1, batch_size, output_size)
        # outputs size (max_length, batch_size, output_size)
        return outputs
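
# Quick shape check (a sketch added for illustration, not part of the original
# gist; uncomment to run):
#   model = RNNModel()
#   dummy = Variable(torch.zeros(4, max_length, 2).long())
#   print(model(dummy).size())   # -> (max_length, 4, output_size)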

def train():
    model = RNNModel()
    print(model)
    model_optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    epoch = 0
    while epoch < n_epochs:
        epoch += 1
        # zero gradients
        model_optimizer.zero_grad()
        batch_input_var, batch_target_var = batch_generator(batch_size)
        outputs = model(batch_input_var)
        batch_target_var = batch_target_var.transpose(0, 1)
        loss = 0
        # the sum of two operands up to 1e9 has at most 10 digits, so only the
        # first 10 (reversed) positions can carry a non-zero target digit
        for i in range(10):
            loss += criterion(outputs[i], batch_target_var[i])
        loss.backward()
        model_optimizer.step()
        if epoch % evaluate_every == 0:
            evaluate(model)
            print('epoch: %d loss: %.4f' % (epoch, loss.item()))
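
# An equivalent single-call formulation (an untested alternative, not in the
# original gist): flatten time and batch before the loss,
#   loss = criterion(outputs.view(-1, output_size),
#                    batch_target_var.contiguous().view(-1))
# note that CrossEntropyLoss averages over all positions in this form, so the
# printed loss scale would differ from the summed per-position loop above.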

def evaluate(model):
    model.train(False)
    input_var, target_var = batch_generator(1)
    outputs = model(input_var)
    # reshape to (2, max_length) so the two operands print as rows
    input_seq = input_var.squeeze(0).transpose(0, 1)
    print('input:')
    print(input_seq)
    print('target:')
    print(target_var)
    print('predict:')
    predict = outputs.squeeze(1)
    # take the argmax over the n_class logits at every position
    topv, topi = predict.topk(1, dim=1)
    print(topi.transpose(0, 1))
    model.train(True)
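
# Note on reading the printout: sequences are stored least significant digit
# first, so the rows above must be read right-to-left to recover the usual
# digit order.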

def batch_generator(batch_size=10, max_value=1000000000):
    batch_input = Variable(torch.LongTensor(batch_size, max_length, 2))
    batch_target = Variable(torch.LongTensor(batch_size, max_length))
    for i in range(batch_size):
        input_a, input_b = random.randint(0, max_value), random.randint(0, max_value)
        target = input_a + input_b
        # split each number into its decimal digits
        input_a_seq, input_b_seq, target_seq = map(lambda v: [int(x) for x in str(v)], (input_a, input_b, target))
        # stack the two padded operand sequences along the last dimension
        batch_input[i] = torch.cat((pad_seq(input_a_seq).unsqueeze(0), pad_seq(input_b_seq).unsqueeze(0)), dim=0).transpose(0, 1)
        # target
        batch_target[i] = pad_seq(target_seq)
    return batch_input, batch_target
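
# Worked example (hand-computed for illustration): for the pair 123 + 88 = 211
# the generator yields, after pad_seq below:
#   operand a: [3, 2, 1, 0, ...]   digits of 123, reversed then zero-padded
#   operand b: [8, 8, 0, 0, ...]   digits of 88, reversed then zero-padded
#   target:    [1, 1, 2, 0, ...]   digits of 211, reversed then zero-padded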

def pad_seq(src_list):
    # least significant digit first, then zero-pad up to max_length
    padded = src_list[::-1] + [0] * (max_length - len(src_list))
    return Variable(torch.LongTensor(padded))

if __name__ == '__main__':
    train()