Created
July 17, 2019 18:06
-
-
Save Hanrui-Wang/12cb4d05507817edeb9e14e829cc6cd8 to your computer and use it in GitHub Desktop.
How to handle the input sequence of a seq2seq model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Voc:
    """Vocabulary that maps words to integer indexes and back.

    The three special tokens (``PAD_token``, ``SOS_token``, ``EOS_token``)
    are pre-registered in ``index2word``; regular words receive indexes
    starting at 3, in the order in which they are first added.
    """

    def __init__(self, name):
        """Create an empty vocabulary labelled *name*."""
        self.name = name
        self.trimmed = False  # becomes True once trim() has run
        self.word2index = {}  # word -> index
        self.word2count = {}  # word -> number of times the word was added
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Register every space-separated token of *sentence*.

        The sentence is split on single spaces, so leading, trailing or
        repeated spaces produce empty-string tokens, which are registered
        like any other word. This mirrors the tokenisation used by
        ``indexesFromSentence``.
        """
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Add *word* to the vocabulary, or bump its count if already known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        """Drop every word seen fewer than *min_count* times.

        Kept words are re-indexed from 3 in their original insertion order
        and keep their observed counts. Only the first call has any effect;
        later calls return immediately.
        """
        if self.trimmed:
            return
        self.trimmed = True

        kept_counts = {
            word: count
            for word, count in self.word2count.items()
            if count >= min_count
        }

        total = len(self.word2index)
        # An empty vocabulary would otherwise divide by zero.
        ratio = len(kept_counts) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(
            len(kept_counts), total, ratio
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count default tokens

        for word, count in kept_counts.items():
            self.addWord(word)
            # addWord() restarts the count at 1; restore the real frequency.
            self.word2count[word] = count
# Lowercase, split off sentence punctuation and drop non-letter characters
def normalizeString(s):
    """Return *s* lowercased and reduced to ASCII letters, ``.``, ``!`` and ``?``.

    A space is inserted before each ``.``, ``!`` or ``?`` so that splitting
    on spaces separates it from the preceding word. Every run of other
    characters (digits, whitespace, other punctuation, non-ASCII letters)
    collapses into a single space. The result has no leading or trailing
    space, so splitting it on ``' '`` never yields an empty token unless
    the result itself is empty.
    """
    s = s.lower()
    # Put a space in front of sentence punctuation.
    s = re.sub(r"([.!?])", r" \1", s)
    # Replace each run of disallowed characters with one space.
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    # Runs at either end leave a stray space behind; remove it.
    return s.strip()
# Convert a sentence string into its word indexes, terminated by EOS
def indexesFromSentence(voc, sentence):
    """Return the index of each space-separated token of *sentence*, then ``EOS_token``.

    Tokens are looked up in ``voc.word2index``; a token that is missing
    from that mapping raises ``KeyError``. The sentence is split on single
    spaces, so repeated spaces produce empty-string tokens that are looked
    up like any other token.
    """
    tokens = sentence.split(' ')
    lookup = voc.word2index
    indexes = []
    for token in tokens:
        indexes.append(lookup[token])
    indexes.append(EOS_token)
    return indexes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment