Skip to content

Instantly share code, notes, and snippets.

@omarsar
Created August 19, 2018 03:10
Show Gist options
  • Save omarsar/072338aafab36eb3cb1bbd24212a2bc3 to your computer and use it in GitHub Desktop.
Save omarsar/072338aafab36eb3cb1bbd24212a2bc3 to your computer and use it in GitHub Desktop.
class ImageRNN(nn.Module):
    """Single-layer RNN that classifies a sequence from its final hidden state.

    Each input sample is a sequence of ``n_steps`` vectors of size
    ``n_inputs``. The sequence is run through ``nn.RNN`` and the last hidden
    state is projected by a linear layer to ``n_outputs`` scores.

    Args:
        batch_size: Initial batch size. ``forward`` overwrites it with the
            actual batch size of every input it sees.
        n_steps: Sequence length (stored only; ``forward`` does not read it).
        n_inputs: Feature size of each time step.
        n_neurons: Hidden-state size of the RNN.
        n_outputs: Size of the output (e.g. number of classes).
    """

    def __init__(self, batch_size, n_steps, n_inputs, n_neurons, n_outputs):
        super().__init__()
        self.n_neurons = n_neurons
        self.batch_size = batch_size
        self.n_steps = n_steps
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        # nn.RNN defaults: 1 layer, unidirectional, sequence-first input
        # of shape (n_steps, batch, n_inputs).
        self.basic_rnn = nn.RNN(self.n_inputs, self.n_neurons)
        self.FC = nn.Linear(self.n_neurons, self.n_outputs)

    def init_hidden(self, device=None, dtype=None):
        """Return an all-zeros initial hidden state.

        Args:
            device: Device for the tensor; defaults to the model's device.
            dtype: Dtype for the tensor; defaults to the model's dtype.

        Returns:
            Tensor of shape (num_layers=1, batch_size, n_neurons).
        """
        # Default to the parameters' placement so the state matches the model
        # after .to(device) / .double(); a CPU float32 tensor would not.
        ref = self.FC.weight
        return torch.zeros(
            1,
            self.batch_size,
            self.n_neurons,
            device=ref.device if device is None else device,
            dtype=ref.dtype if dtype is None else dtype,
        )

    def forward(self, X):
        """Compute output scores for a batch of sequences.

        Args:
            X: Tensor of shape (batch_size, n_steps, n_inputs).

        Returns:
            Tensor of shape (batch_size, n_outputs).

        Side effects:
            Sets ``self.batch_size`` to the batch size of ``X`` and
            ``self.hidden`` to the RNN's final hidden state.
        """
        # (batch_size, n_steps, n_inputs) -> (n_steps, batch_size, n_inputs)
        X = X.permute(1, 0, 2)
        self.batch_size = X.size(1)
        # Build the initial state on the same device/dtype as the input.
        self.hidden = self.init_hidden(device=X.device, dtype=X.dtype)
        # Only the final hidden state is used; per-step outputs are discarded.
        _, self.hidden = self.basic_rnn(X, self.hidden)
        out = self.FC(self.hidden)  # (1, batch_size, n_outputs)
        return out.view(-1, self.n_outputs)  # batch_size X n_outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment