trying to learn pytorch rnn by building a contextualized sequence generator
import torch
from torch import nn
import torch.nn.functional as F
import random
seq_length = 10
token_size = 4
hidden_size = 20
n_layer = 2
context_size = 12
'''
a simple sequence generator from a context
the context is a number between 0 and 11
the sequence is a sequence of 10 numbers that are different depending on the context
each number in the sequence is 1, 2 or 3; 0 is a special <start> token
example:
    6  => [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    10 => [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]
    5  => [1, 3, 2, 1, 3, 2, 1, 3, 2, 1]
the context is encoded as 1-hot
the sequence is also encoded as 1-hot
'''
def gen_data():
    cxt = random.randint(0, context_size - 1)
    sequence = []
    for i in range(seq_length):
        sequence.append(1 + ((i * cxt) % (token_size - 1)))
    # teacher-forcing input: the target sequence shifted right by one step,
    # with the <start> token 0 prepended
    return cxt, [0] + sequence[:-1], sequence
def ctx_to_torch(ctx):
    ctx_torch = torch.zeros(1, context_size)
    ctx_torch[0][ctx] = 1.0
    return ctx_torch

def seq_to_torch(seq):
    seq_torch = torch.zeros(1, len(seq), token_size)
    for i, s in enumerate(seq):
        seq_torch[0][i][s] = 1.0
    return seq_torch
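
# for example, with the constants above:
#   ctx_to_torch(3)      -> a [1, 12] tensor with a single 1.0 at index 3
#   seq_to_torch([0, 1]) -> a [1, 2, 4] tensor, one one-hot row per token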
# takes in a context and generates a sequence from that context
class ContextRNN(nn.Module):

    def __init__(self, context_size, token_size, hidden_size, n_layer):
        super(ContextRNN, self).__init__()
        # some constants
        self.context_size = context_size
        self.token_size = token_size
        self.hidden_size = hidden_size
        self.n_layer = n_layer
        self.context_encoder = nn.Linear(context_size, n_layer * hidden_size)
        self.gru = nn.GRU(token_size, hidden_size, n_layer)
        self.output_decoder = nn.Linear(hidden_size, token_size)
        self.optim = torch.optim.RMSprop(self.parameters())
    def forward(self, context, input):
        '''
        context : [batch, context_size]
        input   : [batch, seq_length, token_size]
        '''
        # transpose the input to [seq_l, batch, tk_size], the layout nn.GRU expects
        input = input.transpose(0, 1)
        # encode the context as the first hidden value
        # [batch, context_size] ==encode==> [n_layer, batch, hidden_size]
        encoded_context = F.relu(self.context_encoder(context))
        h0 = encoded_context.view(-1, self.n_layer, self.hidden_size)
        h0 = h0.transpose(0, 1).contiguous()
        # now we use a gru that outputs [seq_l, batch, hidden_size]
        gru_outputs, hn = self.gru(input, h0)
        # decode output back to sequence [seq_l, batch, tk_size] by dec then softmax
        outputs = self.output_decoder(gru_outputs)
        outputs = nn.Softmax(dim=2)(outputs)
        return outputs
    def roll_out(self, context):
        # generate a sequence autoregressively: encode the context as the
        # initial hidden state, then repeatedly feed the previous prediction
        # back in as the next input, starting from the <start> token 0
        # [batch, context_size] ==encode==> [n_layer, batch, hidden_size]
        encoded_context = F.relu(self.context_encoder(context))
        h = encoded_context.view(-1, self.n_layer, self.hidden_size)
        h = h.transpose(0, 1).contiguous()
        token = seq_to_torch([0]).transpose(0, 1)  # [1, 1, token_size]
        generated = []
        for _ in range(seq_length):
            gru_output, h = self.gru(token, h)
            probs = nn.Softmax(dim=2)(self.output_decoder(gru_output))
            next_token = torch.argmax(probs, dim=2).item()
            generated.append(next_token)
            token = seq_to_torch([next_token]).transpose(0, 1)
        return generated
    def cost(self, outputs, output_targets):
        seq_l = outputs.size()[0]
        # swap dimensions so targets are also [seq_l, batch, tk_size]
        output_targets = output_targets.transpose(0, 1)
        loss = nn.BCELoss()
        cost = 0
        for i in range(seq_l):
            output, output_target = outputs[i], output_targets[i]
            cost_i = loss(output, output_target)
            cost += cost_i
        return cost
    def learn(self, context, in_seq, out_seq):
        out_seq_pred = self(context, in_seq)
        cost = self.cost(out_seq_pred, out_seq)
        self.optim.zero_grad()
        cost.backward()
        self.optim.step()
        return cost
rnn = ContextRNN(context_size, token_size, hidden_size, n_layer)

for i in range(100):
    ctx, in_seq, out_seq = gen_data()
    ctx_torch = ctx_to_torch(ctx)
    seq_in_torch = seq_to_torch(in_seq)
    seq_out_torch = seq_to_torch(out_seq)
    cost = rnn.learn(ctx_torch, seq_in_torch, seq_out_torch)
    print(i, cost.item())
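
# a minimal sanity-check sketch: after training, roll out a sequence for each
# context and compare it against the formula used in gen_data; with only 100
# training steps the rollouts may not match the targets yet
with torch.no_grad():
    for ctx in range(context_size):
        sampled = rnn.roll_out(ctx_to_torch(ctx))
        expected = [1 + ((i * ctx) % (token_size - 1)) for i in range(seq_length)]
        print(ctx, sampled, expected)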