# trying to learn pytorch rnn by building a contextualized sequence generator
import random

import torch
from torch import nn
import torch.nn.functional as F
seq_length = 10
token_size = 4
hidden_size = 20
n_layer = 2
context_size = 12
'''
a simple sequence generator conditioned on a context

the context is a number between 0 and 11
the sequence is 10 numbers whose pattern depends on the context
each number in the sequence is 1, 2, or 3; 0 is a special <start> token

example:
6  => [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
10 => [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]
5  => [1, 3, 2, 1, 3, 2, 1, 3, 2, 1]

the context is encoded as one-hot
the sequence is also encoded as one-hot
'''
def gen_data():
    cxt = random.randint(0, context_size - 1)
    sequence = []
    for i in range(seq_length):
        sequence.append(1 + ((i * cxt) % (token_size - 1)))
    # teacher forcing: the input sequence is the target shifted right by one,
    # starting with the <start> token 0
    return cxt, [0] + sequence[:-1], sequence
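
# for illustration, one (context, input, target) triple this produces,
# e.g. for cxt = 10:
#   in_seq  = [0, 1, 2, 3, 1, 2, 3, 1, 2, 3]
#   out_seq = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1]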
def ctx_to_torch(ctx):
    ctx_torch = torch.zeros(1, context_size)
    ctx_torch[0][ctx] = 1.0
    return ctx_torch
def seq_to_torch(seq):
    seq_torch = torch.zeros(1, len(seq), token_size)
    for i, s in enumerate(seq):
        seq_torch[0][i][s] = 1.0
    return seq_torch
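
# a small sanity check of the encodings (illustrative):
#   ctx_to_torch(3) is a [1, 12] tensor with a 1.0 at index 3
#   seq_to_torch([0, 2]) is a [1, 2, 4] tensor whose two rows are
#   [1, 0, 0, 0] and [0, 0, 1, 0]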
# takes in a context and generates a sequence from that context
class ContextRNN(nn.Module):

    def __init__(self, context_size, token_size, hidden_size, n_layer):
        super(ContextRNN, self).__init__()
        # some constants
        self.context_size = context_size
        self.token_size = token_size
        self.hidden_size = hidden_size
        self.n_layer = n_layer
        # maps the one-hot context to the initial hidden state of every GRU layer
        self.context_encoder = nn.Linear(context_size, n_layer * hidden_size)
        self.gru = nn.GRU(token_size, hidden_size, n_layer)
        self.output_decoder = nn.Linear(hidden_size, token_size)
        self.optim = torch.optim.RMSprop(self.parameters())
    def forward(self, context, input):
        '''
        context : [batch, context_size]
        input : [batch, seq_length, token_size]
        '''
        # transpose the input to [seq_l, batch, tk_size], the layout nn.GRU expects
        input = input.transpose(0, 1)
        # encode the context as the initial hidden state
        # [batch, context_size] ==encode==> [n_layer, batch, hidden_size]
        encoded_context = F.relu(self.context_encoder(context))
        h0 = encoded_context.view(-1, self.n_layer, self.hidden_size)
        h0 = h0.transpose(0, 1).contiguous()
        # the gru outputs [seq_l, batch, hidden_size]
        gru_outputs, hn = self.gru(input, h0)
        # decode the outputs back to a sequence [seq_l, batch, tk_size], then softmax
        outputs = self.output_decoder(gru_outputs)
        outputs = F.softmax(outputs, dim=2)
        return outputs
    def roll_out(self, context):
        # encode the context as the initial hidden state, as in forward
        # [batch, context_size] ==encode==> [n_layer, batch, hidden_size]
        encoded_context = F.relu(self.context_encoder(context))
        h = encoded_context.view(-1, self.n_layer, self.hidden_size)
        h = h.transpose(0, 1).contiguous()
        # generate greedily one token at a time, feeding each prediction
        # back in as the next input, starting from the <start> token 0
        token = torch.zeros(1, 1, self.token_size)
        token[0][0][0] = 1.0
        generated = []
        for _ in range(seq_length):
            gru_output, h = self.gru(token, h)
            output = F.softmax(self.output_decoder(gru_output), dim=2)
            pred = torch.argmax(output[0][0]).item()
            generated.append(pred)
            token = torch.zeros(1, 1, self.token_size)
            token[0][0][pred] = 1.0
        return generated
    def cost(self, outputs, output_targets):
        seq_l = outputs.size()[0]
        # bring the targets to [seq_l, batch, tk_size] to match the outputs
        output_targets = output_targets.transpose(0, 1)
        loss = nn.BCELoss()
        cost = 0
        for i in range(seq_l):
            output, output_target = outputs[i], output_targets[i]
            cost_i = loss(output, output_target)
            cost += cost_i
        return cost
    def learn(self, context, in_seq, out_seq):
        out_seq_pred = self(context, in_seq)
        cost = self.cost(out_seq_pred, out_seq)
        self.optim.zero_grad()
        cost.backward()
        self.optim.step()
        return cost
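
# a minimal shape check (illustrative sketch): forward maps a [1, context_size]
# context and a [1, seq_length, token_size] input to per-step token
# distributions of shape [seq_length, 1, token_size]
#   rnn = ContextRNN(context_size, token_size, hidden_size, n_layer)
#   out = rnn(ctx_to_torch(3), seq_to_torch([0] + [1] * (seq_length - 1)))
#   assert out.size() == (seq_length, 1, token_size)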
rnn = ContextRNN(context_size, token_size, hidden_size, n_layer)

for i in range(100):
    ctx, in_seq, out_seq = gen_data()
    ctx_torch = ctx_to_torch(ctx)
    seq_in_torch = seq_to_torch(in_seq)
    seq_out_torch = seq_to_torch(out_seq)
    cost = rnn.learn(ctx_torch, seq_in_torch, seq_out_torch)
    print(i, cost.item())
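
# a quick qualitative check (illustrative sketch): after training, greedily
# roll out a sequence for each context and compare against the ground-truth
# pattern 1 + ((i * ctx) % (token_size - 1))
for ctx in range(context_size):
    print(ctx, rnn.roll_out(ctx_to_torch(ctx)))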