Skip to content

Instantly share code, notes, and snippets.

@MaximumEntropy
Created October 18, 2017 20:40
Show Gist options
  • Save MaximumEntropy/b91a33e69130aaa3b1b43db7b0cd6dd0 to your computer and use it in GitHub Desktop.
Save MaximumEntropy/b91a33e69130aaa3b1b43db7b0cd6dd0 to your computer and use it in GitHub Desktop.
Peephole GRU
class PeepholeGRU(nn.Module):
    """A single-layer Gated Recurrent Unit (GRU) with peephole connections.

    In addition to the usual input and hidden contributions, every gate
    receives a linear "peephole" term computed from a fixed context vector
    ``ctx`` that is shared across all time steps.
    """

    def __init__(
        self, input_dim, hidden_dim, n_layers,
        dropout=0., batch_first=True
    ):
        """Initialize parameters.

        Args:
            input_dim: size of each input feature vector.
            hidden_dim: size of the hidden state (and of ``ctx``).
            n_layers: requested number of layers. NOTE(review): only a single
                layer of weights is created; stacking is not implemented.
            dropout: stored for API parity; not applied by this cell.
            batch_first: if True, input/output are (batch, seq, feature);
                if False, (seq, batch, feature).
        """
        super(PeepholeGRU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # BUG FIX: was hard-coded to 1, silently ignoring the n_layers argument.
        self.num_layers = n_layers
        self.dropout = dropout
        # BUG FIX: batch_first was accepted but ignored; forward() now honors it.
        self.batch_first = batch_first
        # Each Linear produces the three gate pre-activations (reset, input,
        # new) concatenated along the feature dimension, hence 3 * hidden_dim.
        self.input_weights = nn.Linear(self.input_dim, 3 * self.hidden_dim)
        self.hidden_weights = nn.Linear(self.hidden_dim, 3 * self.hidden_dim)
        self.peep_weights = nn.Linear(self.hidden_dim, 3 * self.hidden_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize all parameters uniformly in [-1/sqrt(H), 1/sqrt(H)]."""
        stdv = 1.0 / math.sqrt(self.hidden_dim)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hidden, ctx):
        r"""Propagate input through the layer.

        inputs:
            input  - batch size x target sequence length x embedding dimension
                     (seq length first if ``batch_first=False``)
            hidden - batch size x hidden dimension
            ctx    - batch size x hidden dimension
        returns: output, hidden
            output - batch size x target sequence length x hidden dimension
                     (seq length first if ``batch_first=False``)
            hidden - batch size x hidden dimension (final step)
        """
        def recurrence(x_t, h_tm1, ctx):
            """One GRU step; the peephole term from ctx is added to every gate."""
            input_gate = self.input_weights(x_t)
            hidden_gate = self.hidden_weights(h_tm1)
            peep_gate = self.peep_weights(ctx)
            i_r, i_i, i_n = input_gate.chunk(3, 1)
            h_r, h_i, h_n = hidden_gate.chunk(3, 1)
            p_r, p_i, p_n = peep_gate.chunk(3, 1)
            # torch.sigmoid/torch.tanh replace the deprecated F.sigmoid/F.tanh.
            resetgate = torch.sigmoid(i_r + h_r + p_r)
            inputgate = torch.sigmoid(i_i + h_i + p_i)
            newgate = torch.tanh(i_n + resetgate * h_n + p_n)
            # Standard GRU blend: hy = (1 - z) * n + z * h, written as
            # n + z * (h - n).
            return newgate + inputgate * (h_tm1 - newgate)

        # Work internally in (seq, batch, feature) order.
        if self.batch_first:
            input = input.transpose(0, 1)
        output = []
        for t in range(input.size(0)):
            hidden = recurrence(input[t], hidden, ctx)
            output.append(hidden)
        # torch.stack replaces the equivalent torch.cat(...).view(...) dance;
        # the dead isinstance(hidden, tuple) branch was removed (recurrence
        # always returns a plain tensor).
        output = torch.stack(output, 0)
        if self.batch_first:
            output = output.transpose(0, 1)
        return output, hidden
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment