Skip to content

Instantly share code, notes, and snippets.

@omarsar
Last active February 27, 2020 16:02
Show Gist options
  • Save omarsar/59475857b5ab8a487a800641c70257f8 to your computer and use it in GitHub Desktop.
Save omarsar/59475857b5ab8a487a800641c70257f8 to your computer and use it in GitHub Desktop.
class CleanBasicRNN(nn.Module):
    """Minimal recurrent network that unrolls a single ``nn.RNNCell`` by hand.

    The hidden state lives on the instance as ``self.hx`` and is overwritten
    at every time step, so it carries over from one ``forward`` call to the
    next. It is stored as a plain tensor attribute, not as a parameter or a
    registered buffer.

    Args:
        batch_size: Number of rows of the hidden state; inputs passed to
            ``forward`` must use the same batch size.
        n_inputs: Size of each input vector fed to the cell.
        n_neurons: Size of the hidden state produced by the cell.
    """

    def __init__(self, batch_size, n_inputs, n_neurons):
        super(CleanBasicRNN, self).__init__()
        self.rnn = nn.RNNCell(n_inputs, n_neurons)
        # Random initial hidden state, one row per sample in the batch.
        self.hx = torch.randn(batch_size, n_neurons)

    def forward(self, X):
        """Run the cell over every time step of ``X``.

        Args:
            X: Sequence of per-step inputs indexed along its first dimension;
                each ``X[t]`` is passed to the cell together with the current
                hidden state (so it should be ``(batch_size, n_inputs)``).

        Returns:
            A tuple ``(output, hx)`` where ``output`` is a list holding the
            hidden state after each time step and ``hx`` is the hidden state
            after the last step (the same tensor as ``output[-1]`` when ``X``
            is non-empty).
        """
        output = []
        # Walk all time steps instead of a hard-coded count of two, so
        # sequences of any length are processed in full.
        for x_t in X:
            self.hx = self.rnn(x_t, self.hx)
            output.append(self.hx)
        return output, self.hx
# Sizes used by the demo below.
FIXED_BATCH_SIZE = 4  # our batch size is fixed for now
N_INPUT = 3
N_NEURONS = 5

# Input literal laid out as (time step, sample, feature): 2 x 4 x 3.
X_batch = torch.tensor(
    [
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 1]],  # X0
        [[9, 8, 7], [0, 0, 0], [6, 5, 4], [3, 2, 1]],  # X1
    ],
    dtype=torch.float,
)

model = CleanBasicRNN(
    batch_size=FIXED_BATCH_SIZE,
    n_inputs=N_INPUT,
    n_neurons=N_NEURONS,
)
output_val, states_val = model(X_batch)

print(output_val)  # contains all output for all timesteps
print(states_val)  # contains values for final state or final timestep, i.e., t=1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment