def init_weights(self):
    """
    Here we reproduce the Keras default initialization weights for the Embedding/LSTM layers:
    uniform for the embedding, Glorot uniform for the input-to-hidden weights,
    orthogonal for the hidden-to-hidden weights and zeros for the biases.
    """
    ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
    hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
    b = (param.data for name, param in self.named_parameters() if 'bias' in name)
    nn.init.uniform_(self.embed.weight.data, a=-0.5, b=0.5)
    for t in ih:
        nn.init.xavier_uniform_(t)
    # the excerpt ends here in the original; the remaining generators are
    # consumed the same way, following the Keras defaults named above
    for t in hh:
        nn.init.orthogonal_(t)
    for t in b:
        nn.init.constant_(t, 0)
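For context, a minimal sketch of the kind of module this method would live in; the class name and layer sizes are assumptions for illustration, not the original model:

import torch.nn as nn

class TinySentenceEncoder(nn.Module):
    # Hypothetical host module: all it needs is an `embed` layer plus LSTM-style
    # parameter names ('weight_ih', 'weight_hh', 'bias') for init_weights() to find.
    def __init__(self, vocab_size=1000, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.init_weights()

    init_weights = init_weights  # attach the function defined above as a method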
class DeepMojiBatchSampler(object):
    """A batch sampler that enables larger epochs on small datasets and
    has upsampling functionality.

    # Arguments:
        y_in: Labels of the dataset.
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary classification problems.
        seed: Random number generator seed.
    """
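A quick usage sketch; the constructor signature is inferred from the docstring arguments above and should be treated as an assumption:

import numpy as np

# toy binary labels; a batch sampler yields lists of dataset indices per batch
y_train = np.random.randint(0, 2, size=200)
sampler = DeepMojiBatchSampler(y_train, batch_size=32,
                               epoch_size=1000, upsample=True, seed=42)
for batch_indices in sampler:
    pass  # epoch_size controls samples drawn per epoch, independent of len(y_train)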
class DeepMojiDataset(Dataset):
    """ A simple Dataset class.
    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
    # __getitem__ output:
        (torch.LongTensor, torch.LongTensor)
    """
    def __init__(self, X_in, y_in):
        # truncated in the original excerpt; a minimal body consistent with the docstring:
        self.X_in = torch.LongTensor(X_in)
        self.y_in = torch.LongTensor(y_in)

    def __len__(self):
        return len(self.X_in)

    def __getitem__(self, index):
        return self.X_in[index], self.y_in[index]
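Usage under the same assumptions (toy numpy data; torch and Dataset imports come from the surrounding module):

import numpy as np
from torch.utils.data import DataLoader

X = np.random.randint(1, 100, size=(200, 30))  # toy zero-padded index sequences
y = np.random.randint(0, 2, size=200)          # toy binary labels
dataset = DeepMojiDataset(X, y)
x0, y0 = dataset[0]     # -> (torch.LongTensor, torch.LongTensor), as documented
loader = DataLoader(dataset, batch_size=32, shuffle=True)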
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# input_seqs is a batch of input sequences as a numpy array of integers (word indices in the vocabulary) padded with zeros
input_seqs = torch.from_numpy(input_seqs.astype('int64'))

# First: order the batch by decreasing sequence length
input_lengths = torch.LongTensor([torch.max(input_seqs[i, :].nonzero()) + 1 for i in range(input_seqs.size()[0])])
input_lengths, perm_idx = input_lengths.sort(0, descending=True)
input_seqs = input_seqs[perm_idx][:, :input_lengths.max()]

# Then pack the sequences
packed_input = pack_padded_sequence(input_seqs, input_lengths.cpu().numpy(), batch_first=True)
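Once a recurrent layer has consumed the packed batch, the length-sort has to be undone so rows line up with the original inputs again; a short sketch of that inverse permutation (the lstm variable is assumed):

from torch.nn.utils.rnn import pad_packed_sequence

packed_output, _ = lstm(packed_input)            # lstm is assumed defined elsewhere
output, _ = pad_packed_sequence(packed_output, batch_first=True)

# invert the descending-length sort so row i matches input sequence i again
_, unperm_idx = perm_idx.sort(0)
output = output[unperm_idx]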
from keras import initializers
from keras.layers import Layer

class AttentionWeightedAverage(Layer):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses 1 parameter per channel to compute the attention value for a single timestep.
    """
    def __init__(self, return_attention=False, **kwargs):
        self.init = initializers.get('uniform')
        self.supports_masking = True
        self.return_attention = return_attention
        super(AttentionWeightedAverage, self).__init__(**kwargs)
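The excerpt cuts off before build/call; a minimal, mask-free sketch of the computation such a layer performs (not the library's exact code):

from keras import backend as K

def attention_weighted_average(x, attention_vector):
    # x: (batch, timesteps, channels); attention_vector: (channels,) learned weights
    logits = K.dot(x, K.expand_dims(attention_vector))   # (batch, timesteps, 1)
    weights = K.softmax(K.squeeze(logits, axis=-1))      # softmax over timesteps
    return K.sum(x * K.expand_dims(weights), axis=1)     # (batch, channels)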
from torch.nn import Module

class Attention(Module):
    """
    Computes a weighted average of channels across timesteps (1 parameter per channel).
    """
    def __init__(self, attention_size, return_attention=False):
        """ Initialize the attention layer

        # Arguments:
            attention_size: Size of the attention vector.
            return_attention: If true, output will include the weight for each input token
                used for the prediction.
        """
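The forward pass is not shown in the excerpt; a hedged sketch of what this kind of layer computes (sequence-length masking omitted, names assumed):

import torch
import torch.nn.functional as F

def attention_forward(inputs, attention_vector):
    # inputs: (batch, timesteps, attention_size); attention_vector: (attention_size,)
    logits = inputs.matmul(attention_vector)                # (batch, timesteps)
    weights = F.softmax(logits, dim=1)                      # attention over timesteps
    weighted = (inputs * weights.unsqueeze(-1)).sum(dim=1)  # (batch, attention_size)
    return weighted, weights    # weights are what return_attention would expose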
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """
    A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.
    """
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = hard_sigmoid(ingate)
    forgetgate = hard_sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = hard_sigmoid(outgate)
    # standard LSTM state update, using the hard-sigmoid gates above
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy
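hard_sigmoid itself is not shown in the excerpt; Keras defines it as a piecewise-linear approximation of the sigmoid, which can be written in PyTorch as:

import torch

def hard_sigmoid(x):
    # Keras-style hard sigmoid: the ramp 0.2*x + 0.5, clipped to [0, 1]
    return torch.clamp(0.2 * x + 0.5, min=0.0, max=1.0)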