Skip to content

Instantly share code, notes, and snippets.

@omarsar
Last active February 27, 2020 16:05
Show Gist options
  • Save omarsar/51c9c01557b8c4de4ea33c8bf27fa3ee to your computer and use it in GitHub Desktop.
Save omarsar/51c9c01557b8c4de4ea33c8bf27fa3ee to your computer and use it in GitHub Desktop.
class BasicRNN(nn.Module):
    """A minimal recurrent layer unrolled over exactly two time steps.

    Both steps share the input-to-hidden weights ``Wx`` and the bias ``b``;
    ``Wy`` feeds the first step's output into the second step::

        Y0 = tanh(X0 @ Wx + b)
        Y1 = tanh(Y0 @ Wy + X1 @ Wx + b)

    Args:
        n_inputs: Number of features in each input row.
        n_neurons: Number of hidden units (width of ``Y0`` and ``Y1``).
    """

    def __init__(self, n_inputs: int, n_neurons: int) -> None:
        super().__init__()
        # nn.Parameter registers each tensor with the module. The weights then
        # appear in .parameters()/state_dict(), follow .to(device), and receive
        # gradients. Plain tensor attributes get none of that.
        # Weights start as unscaled standard-normal samples (torch.randn).
        self.Wx = nn.Parameter(torch.randn(n_inputs, n_neurons))  # n_inputs x n_neurons
        self.Wy = nn.Parameter(torch.randn(n_neurons, n_neurons))  # n_neurons x n_neurons
        self.b = nn.Parameter(torch.zeros(1, n_neurons))  # 1 x n_neurons, broadcast over batch

    def forward(self, X0, X1):
        """Run both time steps and return their outputs.

        Args:
            X0: Input for step 0, shape (batch_size, n_inputs).
            X1: Input for step 1, shape (batch_size, n_inputs).

        Returns:
            ``(Y0, Y1)``, each of shape (batch_size, n_neurons). Both tensors
            are also kept on ``self.Y0`` / ``self.Y1`` for backward
            compatibility. This means the autograd graph of the last call
            stays referenced until the next call.
        """
        # torch.mm only accepts 2-D operands, so X0/X1 must be (batch, n_inputs).
        self.Y0 = torch.tanh(torch.mm(X0, self.Wx) + self.b)
        # Step 1 combines the previous output (through Wy) with the new input.
        self.Y1 = torch.tanh(
            torch.mm(self.Y0, self.Wy) + torch.mm(X1, self.Wx) + self.b
        )
        return self.Y0, self.Y1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment