Last active: August 22, 2019 09:06
-
-
Save yasufumy/ba73b587bd3c516b66fb94b3a90bac71 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import lineflow as lf | |
class Dictionary(object):
    """Two-way vocabulary mapping words to dense integer ids.

    Ids start at 0 and are handed out in the order words are first added,
    so ``idx2word[i]`` is the word whose id is ``i``.
    """

    def __init__(self):
        # word -> id lookup table
        self.word2idx = {}
        # id -> word; a word's position in this list is its id
        self.idx2word = []

    def add_word(self, word):
        """Return the id of ``word``, registering it first if it is new."""
        if word in self.word2idx:
            return self.word2idx[word]
        # The next free id equals the current vocabulary size.
        new_id = len(self.idx2word)
        self.idx2word.append(word)
        self.word2idx[word] = new_id
        return new_id

    def __len__(self):
        """Number of distinct words registered so far."""
        return len(self.idx2word)
class Corpus(object):
    """Word-level corpus whose train/valid/test splits are stored as id tensors.

    ``path`` is a directory expected to contain ``train.txt``, ``valid.txt``
    and ``test.txt``. A single ``Dictionary`` is shared by all three splits,
    so a given word has the same id in every split.
    """

    def __init__(self, path):
        self.dictionary = Dictionary()
        # Order matters: the vocabulary grows while each split is read, so ids
        # follow first-occurrence order across train, then valid, then test.
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Read ``path`` and return all of its tokens as a 1-D int64 tensor of ids.

        Each line is split on whitespace and followed by an ``'<eos>'`` token.
        Words not yet in ``self.dictionary`` are added as they are encountered.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        if not os.path.exists(path):
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            raise FileNotFoundError(path)
        dataset = lf.TextDataset(path, encoding='utf-8').map(lambda x: x.split() + ['<eos>'])
        # add_word returns the (new or existing) id, so building the vocabulary
        # and converting tokens to ids happen in one pass over the file.
        ids = [self.dictionary.add_word(token) for line in dataset for token in line]
        return torch.tensor(ids, dtype=torch.long)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.