A barebones PyTorch implementation of a seq2seq model with attention.
# Michael A. Alcorn
import torch
import torch.autograd as autograd
import torch.nn as nn


def create_lstm(params):
    """Create an LSTM from a dictionary of parameters.

    :param params: keyword arguments for nn.LSTM (input_size, hidden_size, etc.).
    :return: the constructed nn.LSTM.
    """
    return nn.LSTM(**params)


def create_h_0_c_0(params):
    """Create tensors containing the LSTM's initial hidden and cell states.

    :param params: LSTM parameters; "num_directions" and "l_by_d" are added to the
        dictionary as a side effect because they are needed later in the script.
    :return: a tuple (h_0, c_0), each of shape (num_layers * num_directions, hidden_size).
    """
    num_directions = 2 if params["bidirectional"] else 1
    l_by_d = params["num_layers"] * num_directions
    params["num_directions"] = num_directions
    params["l_by_d"] = l_by_d
    hidden_size = params["hidden_size"]
    # autograd.Variable is a no-op wrapper in PyTorch >= 0.4; torch.randn(...,
    # requires_grad=True) is the modern equivalent.
    h_0_var = autograd.Variable(torch.randn(l_by_d, hidden_size), requires_grad=True)
    c_0_var = autograd.Variable(torch.randn(l_by_d, hidden_size), requires_grad=True)
    return (h_0_var, c_0_var)

# Define the size of the input at each step for the encoder and decoder.
input_size = {"e": 6, "d": 4}
# Create the encoder.
e_d = {"e": {"input_size": input_size["e"],
"hidden_size": 5,
"num_layers": 3,
"batch_first": True,
"bidirectional": True}}
encoder = create_lstm(e_d["e"])
# Create the initial hidden state variables for the encoder.
h_0 = {}
c_0 = {}
(h_0["e"], c_0["e"]) = create_h_0_c_0(e_d["e"])
# Create the decoder.
input_size_d = input_size["d"] + e_d["e"]["num_directions"] * e_d["e"]["hidden_size"]
e_d["d"] = {"input_size": input_size_d,
"hidden_size": 9,
"num_layers": 2,
"batch_first": True,
"bidirectional": False}
decoder = create_lstm(e_d["d"])
# Create initial hidden state variables for the decoder.
(h_0["d"], c_0["d"]) = create_h_0_c_0(e_d["d"])
# Create the attention mechanism: a linear layer scores each encoder position from the
# concatenation of its output and the decoder's hidden state, and a softmax over the
# encoder time steps turns the scores into attention weights.
attn = nn.Linear(e_d["e"]["num_directions"] * e_d["e"]["hidden_size"] + e_d["d"]["hidden_size"], 1)
attn_weights = nn.Softmax(dim=1)
# Create dummy input and output sequences.
seq_lens = {"e": 7, "d": 8}
seqs = {}
for (e_or_d, seq_len) in seq_lens.items():
    x = [autograd.Variable(torch.randn((1, input_size[e_or_d]))) for _ in range(seq_len)]
    seqs[e_or_d] = torch.cat(x).view(1, len(x), input_size[e_or_d])
# Calculate encoder outputs.
(out, (h_e, c_e)) = encoder(seqs["e"], (h_0["e"].unsqueeze(1), c_0["e"].unsqueeze(1)))
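# Sanity check (not part of the original gist): with batch_first=True and a bidirectional
# encoder, `out` has shape (batch, seq_len, num_directions * hidden_size) and the final
# hidden/cell states have shape (num_layers * num_directions, batch, hidden_size), i.e.,
# (1, 7, 10) and (6, 1, 5) for the sizes used above.
assert out.shape == (1, seq_lens["e"], e_d["e"]["num_directions"] * e_d["e"]["hidden_size"])
assert h_e.shape == (e_d["e"]["l_by_d"], 1, e_d["e"]["hidden_size"])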
# Calculate the decoder's hidden states, applying the attention mechanism at each step.
(h_t, c_t) = (h_0["d"].unsqueeze(1), c_0["d"].unsqueeze(1))
num_directions_d = e_d["d"]["num_directions"]
hidden_size_d = e_d["d"]["hidden_size"]
input_size_d = input_size["d"]
for i in range(seqs["d"].size(1)):
# h[0] is the output of the bottom layer.
h_att = h_t[0].squeeze(1).unsqueeze(0)
h_ex = h_att.expand(1, seq_lens["e"], hidden_size_d)
concat = torch.cat((out, h_ex), 2)
att = attn(concat)
att_w = attn_weights(att)
attn_applied = torch.bmm(att_w.view(1, 1, seq_lens["e"]),
out)
new_input = torch.cat((seqs["d"][0, i].view(1, 1, input_size_d), attn_applied), dim = 2)
(out_t, (h_t, c_t)) = decoder(new_input, (h_t, c_t))
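
# A minimal sketch of what might follow (not part of the original gist): project the
# decoder's last-step output to the target dimensionality, compute a loss against a
# dummy target, and backpropagate. The projection layer, the random target, and the MSE
# loss are illustrative assumptions; they simply show that gradients flow all the way
# back to the learnable initial states created above.
output_layer = nn.Linear(e_d["d"]["hidden_size"], input_size["d"])
prediction = output_layer(out_t)
target = autograd.Variable(torch.randn(1, 1, input_size["d"]))
loss = nn.MSELoss()(prediction, target)
loss.backward()
print(loss.item(), h_0["e"].grad.shape)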