""" | |
TL; DR: I got annoyed with the way unknown_word_tokens were dealt with in keras | |
""" | |
pad_token = 0
unknown_word_token = 1

# tune these to your corpus
min_occurrences = 3
max_len = 60
class Tokenizer(object):
    def __init__(self):
        # indices 0 and 1 are reserved for the pad and unknown tokens
        self.max_index = 1
        self.word2index = {}
        self.word2count = {}
        self.index2word = {pad_token: "PAD", unknown_word_token: "UNKN"}

    def word_index(self):
        return self.word2index

    def max_words_found(self):
        return self.max_index

    def total_words(self):
        return len(self.word2index)
    def fit_on_texts(self, texts, min_occ):
        # first pass: go through everything and build the vocab dicts
        for sentence in tqdm(texts):
            words = sentence.split(' ')
            for word in words:
                if word in self.word2index:
                    self.word2count[word] += 1
                else:
                    self.max_index += 1
                    self.word2index[word] = self.max_index
                    self.index2word[self.max_index] = word
                    self.word2count[word] = 1
        # second pass: remove the words with a count that's too low
        for w, c in self.word2count.items():
            if c < min_occ:
                del self.index2word[self.word2index[w]]
                del self.word2index[w]
    def texts_to_sequences(self, texts, pad_length):
        out_sentences = []
        for sentence in tqdm(texts):
            words = sentence.split(' ')
            out = []
            # walk the sentence backwards so that, after the final reverse,
            # the padding ends up at the front (pre-padding)
            for word in reversed(words):
                if word in self.word2index:
                    out.append(self.word2index[word])
                else:
                    out.append(unknown_word_token)
            while len(out) < pad_length:
                out.append(pad_token)
            # truncate over-long sentences so every row has the same length,
            # keeping the last pad_length words
            out_sentences.append(out[:pad_length][::-1])
        return np.array(out_sentences)
    def sequences_to_texts(self, sequences):
        out_texts = []
        for seq in tqdm(sequences):
            out = []
            for num in seq:
                # if you wanted safety then you should have used haskell
                out.append(self.index2word[num])
            out_texts.append(' '.join(out))
        return out_texts
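
# A quick usage sketch: the toy sentences, threshold and pad length below are
# illustrative assumptions, just to show the fit -> encode -> decode round trip.
if __name__ == "__main__":
    texts = ["the cat sat", "the dog sat", "the cat ran"]
    tokenizer = Tokenizer()
    # keep words seen at least twice; "dog" and "ran" fall below and get pruned
    tokenizer.fit_on_texts(texts, min_occ=2)
    sequences = tokenizer.texts_to_sequences(texts, pad_length=5)
    print(sequences)  # rows are pre-padded with pad_token; pruned words map to UNKN
    print(tokenizer.sequences_to_texts(sequences))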