Skip to content

Instantly share code, notes, and snippets.

@abhshkdz
Created October 23, 2019 21:40
Show Gist options
  • Save abhshkdz/185f6babd3858fa7c5f0bc986bbca767 to your computer and use it in GitHub Desktop.
Save abhshkdz/185f6babd3858fa7c5f0bc986bbca767 to your computer and use it in GitHub Desktop.
import math
import numpy as np
import torch
import torch.nn as nn
class URLSTMCell(nn.Module):
    """Implementation of the UR-LSTM cell from the paper: Improving the Gating
    Mechanism of Recurrent Neural Networks (https://arxiv.org/abs/1910.09890) by
    Gu et al., 2019.

    The cell combines two ideas from the paper:

    * Uniform gate initialization ("U"): the forget-gate bias ``b_f`` is set to
      ``logit(u)`` with ``u ~ Uniform[1/hidden_size, 1 - 1/hidden_size]``, so
      the initial forget activations ``sigmoid(b_f)`` are spread over that range.
    * Refine gate ("R"): a second gate ``r`` (whose bias is ``-b_f``) adjusts
      the forget gate ``f`` into the effective gate
      ``g = r * (1 - (1 - f)**2) + (1 - r) * f**2``.

    Only the forget and refine gates receive a bias; the input-to-hidden and
    hidden-to-hidden matrix products have none.

    Args:
        input_size: number of features in each input vector.
        hidden_size: number of features in the hidden and cell state.
    """

    def __init__(self, input_size, hidden_size):
        super(URLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Rows are [refine, forget, cell, output] blocks (see forward's chunk).
        # Contents are overwritten by init_weights(), so allocate uninitialized
        # memory instead of drawing (and discarding) normal samples.
        self.weight_ih = nn.Parameter(torch.empty(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(4 * hidden_size, hidden_size))
        self.init_weights()
        # Uniform gate initialization. Sample with torch's RNG rather than
        # NumPy's global RNG so torch.manual_seed makes the whole module's
        # initialization reproducible. low + (high - low) * rand mirrors
        # np.random.uniform, including the degenerate hidden_size <= 2 cases.
        low = 1.0 / hidden_size
        high = 1.0 - 1.0 / hidden_size
        u = low + (high - low) * torch.rand(hidden_size, dtype=torch.float64)
        bias = -torch.log(1.0 / u - 1.0)  # logit(u), the inverse of sigmoid
        self.bias_forgetgate = nn.Parameter(bias.to(self.weight_ih.dtype))

    def init_weights(self):
        """Fill both weight matrices from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).

        ``bias_forgetgate`` is not touched here.
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        self.weight_ih.data.uniform_(-stdv, stdv)
        self.weight_hh.data.uniform_(-stdv, stdv)

    def forward(self, input, state=None):
        """Advance the cell by one time step.

        Args:
            input: tensor of shape ``(batch, input_size)`` (``torch.mm``
                requires a 2-D input).
            state: ``(h, c)`` tuple, each of shape ``(batch, hidden_size)``.
                ``None`` (the default) starts from zero hidden and cell states.

        Returns:
            ``(h_new, (h_new, c_new))``: the new hidden state, plus the state
            tuple to feed into the next step.
        """
        if state is None:
            zeros = input.new_zeros(input.size(0), self.hidden_size)
            state = (zeros, zeros)
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) +
                 torch.mm(hx, self.weight_hh.t()))
        refinegate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        forgetgate = torch.sigmoid(forgetgate + self.bias_forgetgate)
        # The refine gate reuses the forget bias with opposite sign (b_r = -b_f).
        refinegate = torch.sigmoid(refinegate - self.bias_forgetgate)
        # Effective gate g = r*(1 - (1 - f)^2) + (1 - r)*f^2, written expanded.
        g = 2 * refinegate * forgetgate + (1 - 2 * refinegate) * forgetgate ** 2
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        # g is the forget gate and (1 - g) the coupled input gate.
        cy = (g * cx) + ((1 - g) * cellgate)
        hy = outgate * torch.tanh(cy)
        return hy, (hy, cy)
class URLSTM(nn.Module):
    """Single-layer UR-LSTM that unrolls a ``URLSTMCell`` over a sequence.

    Args:
        input_size: number of features in each input vector.
        hidden_size: number of features in the hidden and cell state.
        dropout: accepted for signature compatibility with ``torch.nn.LSTM``
            but has no effect, because with a single layer there is no
            inter-layer output to drop. A non-zero value emits a warning
            instead of being ignored silently.
    """

    def __init__(self, input_size, hidden_size, dropout=0):
        super(URLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        if dropout:
            import warnings
            warnings.warn(
                "URLSTM has a single layer, so dropout={} has no effect"
                .format(dropout))
        self.cell = URLSTMCell(input_size, hidden_size)

    def forward(self, input, state=None):
        """Run the cell over every time step of ``input``.

        Args:
            input: sequence-first tensor. It is unbound along dim 0 and each
                slice is fed to the cell, so it is expected to be
                ``(seq_len, batch, input_size)``.
            state: initial ``(h, c)`` tuple, each ``(batch, hidden_size)``.
                ``None`` (the default) starts from zero hidden and cell states.

        Returns:
            ``(outputs, state)``. ``outputs`` stacks the per-step hidden states
            along dim 0, giving ``(seq_len, batch, hidden_size)``, and
            ``state`` is the final ``(h, c)`` tuple.
        """
        if state is None:
            zeros = input.new_zeros(input.size(1), self.hidden_size)
            state = (zeros, zeros)
        outputs = []
        for step in input.unbind(0):
            out, state = self.cell(step, state)
            outputs.append(out)
        if not outputs:
            # torch.stack rejects an empty list; return an empty sequence and
            # the unchanged initial state for a zero-length input.
            return input.new_zeros(0, input.size(1), self.hidden_size), state
        return torch.stack(outputs), state
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment