Created
January 17, 2018 10:15
-
-
Save moritzschaefer/70aa5527fe64d746bf36044f43a45564 to your computer and use it in GitHub Desktop.
Combination of an LSTM and a 1-D convolutional layer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
import torch | |
from torch.nn.init import kaiming_normal, normal | |
def weights_init(m):
    '''
    Re-initialise the parameters of a Conv1d or Linear module in place.

    Intended to be passed to ``nn.Module.apply``; any other module type is
    left untouched.

    * weight: Kaiming-normal initialisation
    * bias:   standard-normal initialisation (skipped when the layer was
              built with ``bias=False``)
    '''
    if not isinstance(m, (nn.Conv1d, nn.Linear)):
        return

    # In-place variants (trailing underscore); the un-suffixed names are
    # deprecated aliases.
    nn.init.kaiming_normal_(m.weight.data)

    # Kaiming init needs fan-in/fan-out and therefore rejects 1-D tensors
    # with a ValueError, so the bias is drawn from a plain normal
    # distribution directly.
    if m.bias is not None:
        nn.init.normal_(m.bias.data)
class Deep1(nn.Module):
    '''
    A combination of an LSTM and a Conv1d layer feeding one linear output.

    4 is the nucleotide encoding (4 bit per nucleotide)
    30 is the length of our input sequence
    120 is 4*30 the number of sequence features

    input_size is the number of total input features per sample.
    The first 120 have to be the sequence 1-hot-encodings; everything after
    them is treated as additional features and passed straight to the final
    linear layer.

    forward() takes a (batch, input_size) tensor and returns (batch, 1).
    '''
    lstm_hidden = 50
    kernel_size = 4
    # Width of the per-nucleotide encoding and number of nucleotides; their
    # product (120) is the size of the leading sequence part of the input.
    nuc_channels = 4
    seq_len = 30

    def __init__(self, input_size):
        super(Deep1, self).__init__()
        seq_features = self.nuc_channels * self.seq_len
        if input_size < seq_features:
            raise ValueError(
                'input_size must be at least %d (the sequence encoding), got %d'
                % (seq_features, input_size))

        # dropout is a probability; 0.0 disables it (a bool is not a valid
        # value for this argument).
        self.lstm = nn.LSTM(input_size=self.nuc_channels,
                            hidden_size=self.lstm_hidden, num_layers=2,
                            dropout=0.0, bidirectional=False)  # TODO enable?

        self.conv1 = nn.Conv1d(
            in_channels=self.nuc_channels, out_channels=self.nuc_channels,
            kernel_size=self.kernel_size)

        # Length of the conv output along the sequence axis (no padding,
        # stride 1).
        conv_len = self.seq_len - self.kernel_size + 1

        # hidden layers, additional_features, conv output
        self.fc1 = nn.Linear(
            self.lstm_hidden
            + (input_size - seq_features)
            + self.nuc_channels * conv_len,
            1)

        self.apply(weights_init)

    def forward(self, x):
        seq_features = self.nuc_channels * self.seq_len

        # Slice instead of x.split(120, dim=1): split() yields chunks *of
        # size* 120, which is only a two-way split when there are between 1
        # and 120 additional features.
        nuc_features = x[:, :seq_features]
        additional_features = x[:, seq_features:]

        # (batch, seq_len, channels); reshape copies only if it has to.
        nuc = nuc_features.reshape(-1, self.seq_len, self.nuc_channels)

        # lstm needs form (seq_len, batch, input_size)
        lstm_input = nuc.permute(1, 0, 2)
        # return only last seq-output. Form: (batch_size x lstm_hidden)
        lstm_output = self.lstm(lstm_input)[0][-1]

        # conv needs (batch, channels, seq_len);
        # output is batch_size x 4 x 27 (30-kernel_size+1)
        conv1_output = self.conv1(nuc.permute(0, 2, 1))
        # TODO add max-pooling
        conv1_output = conv1_output.reshape(conv1_output.size(0), -1)

        return self.fc1(
            torch.cat([lstm_output, additional_features, conv1_output], 1))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment