This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample title from the trained model | |
def sample(): | |
num_words = 10 | |
# Initialize input step and hidden state | |
input = torch.zeros(1, 1, vocab_size) | |
hidden = (torch.zeros(1, 1, n_hidden).to(device), torch.zeros(1, 1, n_hidden).to(device)) | |
i = 0 | |
output_word = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
import math | |
import matplotlib.pyplot as plt | |
# Set up the number of iterations, printing and plotting options
n_iters = 1100000    # number of iterations to run
print_every = 1000   # printing option: interval, counted in iterations
plot_every = 1000    # plotting option: interval, counted in iterations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Define training procedure | |
def train(sequence, target, device): | |
# Move tensors to device | |
hidden = rnn.initHidden(device) | |
sequence = sequence.to(device) | |
target = target.to(device) | |
rnn.zero_grad() | |
# Forward step |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# setup learning rate and loss function
learning_rate = 0.005
# negative log-likelihood loss; presumably the model emits log-probabilities
# (e.g. via LogSoftmax) -- TODO confirm against the model definition
criterion = nn.NLLLoss()
# device to use (GPU if available, CPU otherwise)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Define a function to convert tensor into index in vocabulary | |
def indexFromTensor(target): | |
''' | |
Function returns tensor containing target index given tensor representing target word | |
''' | |
top_n, top_i = target.topk(1) | |
target_index = top_i[0].item() | |
target_index_tensor = torch.zeros((1), dtype = torch.long) | |
target_index_tensor[0] = target_index |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Define a function which converts output into word | |
def wordFromOutput(output):
    '''
    Function returns the vocabulary word(s) matching the highest-scoring entry
    of output, together with that entry's index.

    Returns a pair (words, index): words is the list of every key in vocab
    whose value equals index (empty if no key matches), index is a plain int.
    '''
    # topk(1) gives the largest value and its position; only the position is used.
    # assumes output is 1-D so that top_i[0] is a single element -- TODO confirm
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    # vocab maps word -> index, so the reverse lookup scans all items
    return [k for (k, v) in vocab.items() if v == category_i], category_i
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# import PyTorch | |
import torch | |
import torch.nn as nn | |
# Create LSTM | |
class SimpleLSTM(nn.Module): | |
''' | |
Simple LSTM model to generate kernel titles. | |
Arguments: | |
- input_size - should be equal to the vocabulary size |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generate sequences out of titles: | |
# Define sequence length | |
sequence_length = 3 | |
# Generate sequences | |
def generate_sequences(titles): | |
sequences = [] | |
targets = [] | |
# Loop for all selected titles |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
# Translate word to an index from vocabulary | |
def wordToIndex(word):
    '''
    Function returns the value stored in vocab for the given word.

    The lookup fails (presumably with KeyError, if vocab is a dict) when the
    word is not a key of vocab.
    '''
    # The end-of-sentence marker is looked up verbatim; every other word is
    # passed through clean_title before the lookup
    if (word != end_of_sentence):
        word = clean_title(word)
    return vocab[word]
# Translate word to 1-hot tensor | |
def wordToTensor(word): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_vocabulary(titles): | |
''' | |
Function to create a vocabulary out of a list of titles | |
''' | |
vocab = set() | |
for title in titles: | |
if (clean_title(title) != ''): | |
words = extract_words(title) | |
vocab.update(words) |
Newer Older