lstm network implementation
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from data.dataset_loader import device  # shadowed by the local `device` defined below

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True


class LinearLayer(nn.Linear):
    def forward(self, input):
        # the .clone() here is important and without it a runtime exception will be thrown while
        # back-propagating, due to changing the params in-place.
        return F.linear(input.clone(), self.weight.clone(), self.bias.clone())


class ForgetGate(nn.Module):
    # sigmoid gate; note that only self.linear is used in forward(),
    # hidden_layer and out_layer are defined but never called
    def __init__(self, in_size, forget_gate_hidden_size, out_size):
        super().__init__()
        self.hidden_layer = LinearLayer(
            in_size, forget_gate_hidden_size, device=device)
        self.out_layer = LinearLayer(
            forget_gate_hidden_size, out_size, device=device)
        self.linear = LinearLayer(in_size, out_size, device=device)

    def forward(self, input):
        return torch.sigmoid(self.linear(input))


class InputGate(nn.Module):
    # candidate cell update: sigmoid gate multiplied by tanh of a linear projection
    def __init__(self, in_size, forget_gate_hidden_size, input_gate_hidden_size, out_size):
        super().__init__()
        self.forget_gate = ForgetGate(
            in_size, forget_gate_hidden_size, out_size)
        self.tanh = LinearLayer(in_size, out_size, device=device)

    def forward(self, input):
        pass_through = self.forget_gate(input)
        inter_state = torch.tanh(self.tanh(input))
        return pass_through * inter_state


class LSTM(nn.Module):
    def __init__(self, in_size, state_size, out_size, forget_gate_hidden_size, input_gate_hidden_size):
        super().__init__()
        self.state_size = state_size
        self.out_size = out_size
        self.forget_gate = ForgetGate(
            in_size + out_size, forget_gate_hidden_size, state_size)
        self.input_gate = InputGate(
            in_size + out_size, forget_gate_hidden_size, input_gate_hidden_size, state_size)
        self.output_gate = ForgetGate(
            in_size + out_size, forget_gate_hidden_size, state_size)

    def forward(self, input, prev_output, prev_cell_state):
        all_input = torch.cat((input, prev_output), 1).detach()
        pass_through = self.forget_gate(all_input).detach()
        input_g = self.input_gate(all_input).detach()
        output_g = self.output_gate(all_input)
        current_state = pass_through * prev_cell_state + input_g
        output = torch.tanh(current_state) * output_g
        return current_state.detach(), output

    def get_init_state(self):
        return torch.zeros(1, self.state_size).detach().to(device)

    def get_init_out(self):
        return torch.zeros((1, self.out_size)).detach().to(device)
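For reference, a minimal sketch of how this cell would be stepped over a sequence. The names `seq` and the sizes below are illustrative, not part of the gist; since the cell's output has shape `(1, state_size)`, the sketch assumes `state_size == out_size`.

# illustrative usage sketch (names and sizes are placeholders, not from the gist)
lstm = LSTM(in_size=8, state_size=16, out_size=16,
            forget_gate_hidden_size=32, input_gate_hidden_size=32)
state = lstm.get_init_state()
out = lstm.get_init_out()
for x in seq:  # each x: tensor of shape (1, in_size) on `device`
    # feed the previous output and cell state back into the cell
    state, out = lstm(x, out, state)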
Second implementation, with the gate weights held as explicit nn.Parameter matrices instead of linear-layer modules:
import math
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True


class LSTM(nn.Module):
    def __init__(self, in_size, state_size, out_size, forget_gate_hidden_size, input_gate_hidden_size):
        super().__init__()
        self.state_size = state_size
        self.out_size = out_size
        # cell state is stored on the module and reassigned every call to forward()
        self.prev_state = torch.Tensor(state_size, 1).to(device)
        # params for the forget-gate (denoted f_t)
        self.W_in_f = nn.Parameter(torch.Tensor(state_size, in_size))
        self.W_h_f = nn.Parameter(torch.Tensor(state_size, out_size))
        self.b_f = nn.Parameter(torch.Tensor(state_size, 1))
        # input gate (i_t and ~C_t)
        self.W_in_i = nn.Parameter(torch.Tensor(state_size, in_size))
        self.W_h_i = nn.Parameter(torch.Tensor(state_size, out_size))
        self.b_it_i = nn.Parameter(torch.Tensor(state_size, 1))
        self.W_C_i = nn.Parameter(torch.Tensor(state_size, in_size))
        self.W_C_h_i = nn.Parameter(torch.Tensor(state_size, out_size))
        self.b_C_i = nn.Parameter(torch.Tensor(state_size, 1))
        # out-gate: o_t
        self.W_in_ot = nn.Parameter(torch.Tensor(state_size, in_size))
        self.W_h_ot = nn.Parameter(torch.Tensor(state_size, out_size))
        self.b_ot = nn.Parameter(torch.Tensor(state_size, 1))
        self.init_params()

    def init_params(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, prev_out, _):
        in_detached = input.detach()
        f_t = torch.sigmoid(self.W_in_f @ in_detached +
                            self.W_h_f.clone() @ prev_out + self.b_f)
        i_t = torch.sigmoid(self.W_in_i @ in_detached +
                            self.W_h_i.clone() @ prev_out + self.b_it_i)
        C_intr = torch.tanh(self.W_C_i @ in_detached +
                            self.W_C_h_i.clone() @ prev_out + self.b_C_i)
        o_t = torch.sigmoid(self.W_in_ot @ in_detached +
                            self.W_h_ot.clone() @ prev_out + self.b_ot)
        self.prev_state = self.prev_state * f_t + (i_t * C_intr)
        out = torch.tanh(self.prev_state) * o_t
        return self.prev_state.detach(), out

    def get_init_state(self):
        return torch.zeros(self.state_size, 1).detach().to(device)

    def get_init_out(self):
        return torch.zeros((self.out_size, 1)).detach().to(device)
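For reference, the parameter names in this second implementation map onto the standard LSTM cell equations from the linked tutorial, with x_t the input, h_{t-1} the previous output, and C_{t-1} the previous cell state:

$$
\begin{aligned}
f_t &= \sigma(\mathtt{W\_in\_f}\,x_t + \mathtt{W\_h\_f}\,h_{t-1} + \mathtt{b\_f}) \\
i_t &= \sigma(\mathtt{W\_in\_i}\,x_t + \mathtt{W\_h\_i}\,h_{t-1} + \mathtt{b\_it\_i}) \\
\tilde{C}_t &= \tanh(\mathtt{W\_C\_i}\,x_t + \mathtt{W\_C\_h\_i}\,h_{t-1} + \mathtt{b\_C\_i}) \\
o_t &= \sigma(\mathtt{W\_in\_ot}\,x_t + \mathtt{W\_h\_ot}\,h_{t-1} + \mathtt{b\_ot}) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$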
Followed this tutorial to see how the LSTM cell works: https://colah.github.io/posts/2015-08-Understanding-LSTMs/

In both implementations I used `.clone()` on some of the tensors (in the first implementation in `LinearLayer.forward()`, in the second implementation in `forward()`). The exception I'm getting without the clones is:

The issue I'm having with these two implementations is that training each batch takes longer and longer than the one before. Did some searching, and it looks like there are tensors that are carried from one loop to the next and need to be detached from the graph.
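The usual cause of that per-batch slowdown is that the recurrent state and output (and, in the second implementation, `self.prev_state`) keep their autograd history across iterations, so each backward pass walks an ever-growing graph. A minimal sketch of the common fix, detaching the carried tensors at the batch boundary rather than inside `forward()`; the loop below is illustrative only, and `batches`, `criterion`, and `optimizer` are placeholder names, not from the gist:

# illustrative training loop showing where to cut the graph between batches
state = lstm.get_init_state()
out = lstm.get_init_out()
for x, target in batches:
    optimizer.zero_grad()
    state, out = lstm(x, out, state)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
    # detach the carried tensors so the next batch starts a fresh graph
    state = state.detach()
    out = out.detach()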