import torch.nn.functional as F

def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """
    A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.
    """
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = hard_sigmoid(ingate)
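The snippet is truncated here. As a sketch of how it plausibly continues, assuming hard_sigmoid mirrors Keras's piecewise-linear hard sigmoid, clip(0.2*x + 0.5, 0, 1):

import torch

def hard_sigmoid(x):
    # Keras-style hard sigmoid: a cheap piecewise-linear approximation of sigmoid
    return torch.clamp(0.2 * x + 0.5, min=0.0, max=1.0)

# ...continuing inside LSTMCell (sketch of the standard LSTM update):
#     forgetgate = hard_sigmoid(forgetgate)
#     cellgate = torch.tanh(cellgate)
#     outgate = hard_sigmoid(outgate)
#     cy = (forgetgate * cx) + (ingate * cellgate)   # new cell state
#     hy = outgate * torch.tanh(cy)                  # new hidden state
#     return hy, cy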
class Attention(Module):
    """
    Computes a weighted average of channels across timesteps (1 parameter per channel).
    """
    def __init__(self, attention_size, return_attention=False):
        """ Initialize the attention layer
        # Arguments:
            attention_size: Size of the attention vector.
            return_attention: If true, output will include the weight for each input token
                used for the prediction
        """
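The embed stops mid-constructor. A plausible completion, assuming the layer keeps one learnable weight per channel in a single Parameter (imported from torch.nn) of size attention_size:

        super(Attention, self).__init__()
        self.return_attention = return_attention
        self.attention_size = attention_size
        # One learnable weight per channel (assumed layout and init)
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))
        self.attention_vector.data.normal_(std=0.05)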
class AttentionWeightedAverage(Layer):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses 1 parameter per channel to compute the attention value for a single timestep.
    """
    def __init__(self, return_attention=False, **kwargs):
        self.init = initializers.get('uniform')
        self.supports_masking = True
        self.return_attention = return_attention
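The constructor is cut off before it finishes; a Keras custom Layer must chain to the base class, so the missing tail is presumably:

        super(AttentionWeightedAverage, self).__init__(**kwargs)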
# input_seqs is a batch of input sequences as a numpy array of integers (word indices in the vocabulary) padded with zeros
input_seqs = Variable(torch.from_numpy(input_seqs.astype('int64')).long())

# First: order the batch by decreasing sequence length
input_lengths = torch.LongTensor([torch.max(input_seqs[i, :].data.nonzero()) + 1 for i in range(input_seqs.size()[0])])
input_lengths, perm_idx = input_lengths.sort(0, descending=True)
input_seqs = input_seqs[perm_idx][:, :input_lengths.max()]

# Then pack the sequences
packed_input = pack_padded_sequence(input_seqs, input_lengths.cpu().numpy(), batch_first=True)
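After the packed batch has gone through a recurrent layer it has to be unpacked and the original batch order restored. A minimal sketch, where lstm stands in for any nn.LSTM built with batch_first=True:

from torch.nn.utils.rnn import pad_packed_sequence

packed_output, _ = lstm(packed_input)  # lstm is a placeholder recurrent module
output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)

# Undo the initial length-sort to recover the original ordering
_, unperm_idx = perm_idx.sort(0)
output = output[unperm_idx]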
class DeepMojiDataset(Dataset):
    """ A simple Dataset class.
    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
    # __getitem__ output:
        (torch.LongTensor, torch.LongTensor)
    """
    def __init__(self, X_in, y_in):
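The embed ends at the constructor. A Dataset like this needs little beyond storage plus __len__ and __getitem__; a plausible body (not necessarily the exact torchMoji code):

        self.X_in = X_in
        self.y_in = y_in

    def __len__(self):
        return len(self.X_in)

    def __getitem__(self, index):
        # Convert on access so the documented (LongTensor, LongTensor) output holds
        X = torch.LongTensor(self.X_in[index])
        y = torch.LongTensor([self.y_in[index]])
        return X, y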
class DeepMojiBatchSampler(object):
    """A batch sampler that enables larger epochs on small datasets and
    has upsampling functionality.
    # Arguments:
        y_in: Labels of the dataset.
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed.
    """
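The body is cut off. A minimal sketch of the sampling idea — epoch_size draws with replacement, optionally balancing the two classes — where the class name and all details are my assumptions, not the original torchMoji code:

import numpy as np

class BalancedSampler(object):  # hypothetical stand-in for the truncated class
    def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        rng = np.random.RandomState(seed)
        y = np.asarray(y_in)
        if upsample:
            # Binary problems only: draw each class equally often
            pos, neg = np.flatnonzero(y == 1), np.flatnonzero(y == 0)
            half = epoch_size // 2
            idx = np.concatenate([rng.choice(pos, half),
                                  rng.choice(neg, epoch_size - half)])
            rng.shuffle(idx)
        else:
            idx = rng.choice(len(y), epoch_size)
        self.indices = idx

    def __iter__(self):
        for i in range(0, self.epoch_size, self.batch_size):
            yield self.indices[i:i + self.batch_size].tolist()

    def __len__(self):
        return (self.epoch_size + self.batch_size - 1) // self.batch_size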
def init_weights(self):
    """
    Here we reproduce the Keras default initialization to initialize the embedding and LSTM weights.
    """
    ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
    hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
    b = (param.data for name, param in self.named_parameters() if 'bias' in name)
    nn.init.uniform(self.embed.weight.data, a=-0.5, b=0.5)
    for t in ih:
        nn.init.xavier_uniform(t)
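Keras's default LSTM setup also uses orthogonal recurrent kernels and zero biases, so the truncated remainder is presumably along these lines (using the same pre-0.4 init names as the surrounding code):

    for t in hh:
        nn.init.orthogonal(t)   # Keras default: orthogonal recurrent kernels
    for t in b:
        nn.init.constant(t, 0)  # Keras default: zero biases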
# Drawn from https://gist.github.com/rocknrollnerd/c5af642cf217971d93f499e8f70fcb72 (in Theano)
# This is implemented in PyTorch
# Author : Anirudh Vemula

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from sklearn.datasets import fetch_mldata
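Only the imports survive here; fetch_mldata suggests the script loads MNIST next, which in that era of scikit-learn would look roughly like:

mnist = fetch_mldata('MNIST original')
X = mnist.data.astype(np.float32) / 255.0   # 70000 x 784 pixel intensities
y = mnist.target.astype(np.int64)           # digit labels 0-9

Note that fetch_mldata was later removed from scikit-learn (deprecated in 0.20, gone in 0.22); fetch_openml('mnist_784') is the modern replacement.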
class Model(nn.Module):
    def __init__(self, vocab_size, embed_dim, H1, H2, H3, pairs_in, single_in, drop=0.5):
        super(Model, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)  # embedding table: vocab_size x embed_dim
        self.drop = nn.Dropout(drop)
        self.pairs = nn.Sequential(nn.Linear(pairs_in, H1), nn.ReLU(), nn.Dropout(drop),
                                   nn.Linear(H1, H2), nn.ReLU(), nn.Dropout(drop),
                                   nn.Linear(H2, H3), nn.ReLU(), nn.Dropout(drop),
                                   nn.Linear(H3, 1),
                                   nn.Linear(1, 1))
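The trailing nn.Linear(1, 1) simply applies a learned affine rescaling to the scalar score produced by the H3-to-1 layer. For scale, instantiating such a model might look like this (every dimension below is a made-up placeholder):

model = Model(vocab_size=20000, embed_dim=64,
              H1=256, H2=128, H3=64,
              pairs_in=128, single_in=64, drop=0.5)
print(model)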
def get_params(module, memo=None, pointers=None):
    """ Returns an iterator over PyTorch module parameters that allows updating the parameters
    themselves (and not only their data).
    ! Side effect: updates shared parameters to point to the first yielded instance
    (i.e. you can update shared parameters and keep them shared)
    Yields:
        (Module, string, Parameter): Tuple containing the parameter's module, name and pointer
    """
    if memo is None:
        memo = set()
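The generator is truncated after the memo initialization. Given the docstring, the rest plausibly walks module._parameters and recurses into children, re-pointing shared duplicates at the first yielded instance; this is a reconstruction, not the verified original:

        pointers = {}
    for name, p in module._parameters.items():
        if p is None:
            continue
        if p not in memo:
            memo.add(p)
            pointers[p] = (module, name)
            yield module, name, p
        else:
            # Shared parameter seen before: re-point it at the first yielded instance
            first_module, first_name = pointers[p]
            module._parameters[name] = first_module._parameters[first_name]
    for child in module.children():
        for m, n, p in get_params(child, memo, pointers):
            yield m, n, p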