Last active
March 2, 2019 09:39
-
-
Save farizrahman4u/d66e52a4dbc7ebf3b91960ae4688d93c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import json | |
from keras.utils import Progbar | |
# dictionary management
# Shared vocabulary state used by all functions below.
# Index max_words (20000) is reserved for the out-of-vocabulary token
# '<UNK>'; pad() also fills with max_words, so padding and unknown words
# share the same index.
idx2word = {20000: '<UNK>'}  # index -> word
word2idx = {'<UNK>': 20000}  # word -> index
frequencies = {}  # word -> occurrence count, incremented by get_index()
max_words = 20000  # vocabulary size cap (only the top-N words are kept)
max_len = 80  # fixed sequence length enforced by pad()
top_word_idxs = None  # lazily-built list of indices of the most frequent words
def save_dicts():
    """Persist the vocabulary state (idx2word, word2idx, frequencies) to dicts.json."""
    payload = [idx2word, word2idx, frequencies]
    with open('dicts.json', 'w') as f:
        json.dump(payload, f)
def load_dicts():
    """Restore vocabulary state from dicts.json and precompute top_word_idxs.

    Reads back the [idx2word, word2idx, frequencies] list written by
    save_dicts() and rebuilds the global list of indices of the max_words
    most frequent words.

    Bug fix: the original looked each index up via get_index(), which
    increments frequencies[word] as a side effect — so merely loading the
    dicts inflated every top word's count by one.  Indices are now read
    directly from word2idx, leaving the loaded counts untouched.

    NOTE(review): json round-tripping turns idx2word's int keys into
    strings; any idx2word[i] lookup after a load needs int->str handling —
    confirm against callers.
    """
    global idx2word, word2idx, frequencies
    global top_word_idxs
    with open('dicts.json', 'r') as f:
        idx2word, word2idx, frequencies = json.load(f)
    # Rank words by descending frequency and keep the max_words most common.
    all_words = sorted(frequencies, key=lambda w: -frequencies[w])
    top_words = all_words[:max_words]
    # Direct lookup is safe: every word in frequencies was registered via
    # get_index(), so it is guaranteed to be present in word2idx.
    top_word_idxs = [word2idx[w] for w in top_words]
def pad(word_idxs):
    """Force word_idxs to exactly max_len entries: truncate long sequences,
    right-pad short ones with the padding index max_words."""
    shortfall = max_len - len(word_idxs)
    if shortfall <= 0:
        return word_idxs[:max_len]
    return word_idxs + [max_words] * shortfall
def remove_uncommon_words(input_word_idxs):
    """Map every index not among the max_words most frequent words to the
    <UNK>/padding index (max_words).

    Fixes two issues in the original:
    * membership was tested against a ~max_words-element *list* for every
      token, O(len(input) * max_words) total; a set makes each test O(1);
    * the lazy build of top_word_idxs used get_index(), whose side effect
      increments frequencies — ranking the words corrupted the very counts
      being ranked.  Indices are read directly from word2idx instead.
    """
    global top_word_idxs
    if top_word_idxs is None:
        # Lazily rank the vocabulary by descending frequency.
        all_words = sorted(frequencies, key=lambda w: -frequencies[w])
        top_words = all_words[:max_words]
        top_word_idxs = [word2idx[w] for w in top_words]
    top_set = set(top_word_idxs)  # O(1) membership instead of a list scan
    return [i if i in top_set else max_words for i in input_word_idxs]
def get_index(word, add_new=True):
    """Return the integer index for word, updating frequency counts.

    Known words: increment frequencies[word] and return the stored index.
    Unknown words: if add_new, register the word under a fresh index with a
    count of 1; otherwise return the <UNK> index (max_words).

    Bug fix: new indices were assigned as len(idx2word).  Because idx2word
    starts with the reserved entry {20000: '<UNK>'}, the 20000th new word
    received index 20000 == max_words, silently overwriting '<UNK>' — and
    every later word collided on that same slot.  Indices that would land
    on max_words are now shifted past it, so the reserved slot is never
    reused.  Vocabularies smaller than max_words get identical indices to
    before, so the change is backward-compatible.
    """
    if word in word2idx:
        frequencies[word] += 1
        return word2idx[word]
    if not add_new:
        return max_words
    idx = len(idx2word)
    if idx >= max_words:
        # Skip over the reserved <UNK>/padding slot at index max_words.
        idx += 1
    idx2word[idx] = word
    word2idx[word] = idx
    frequencies[word] = 1
    return idx
def parse_row(line):
    """Split one CSV row into (review_text, label, split_name).

    Expected columns: id, split, review..., label, file.  Commas inside the
    review are re-joined and the surrounding quote characters stripped.
    Labels map 'pos' -> 1 and 'neg' -> 0; anything else is printed and
    mapped to -1.
    """
    fields = line.split(',')
    split_name = fields[1]
    raw_label = fields[-2]
    if raw_label == 'pos':
        label = 1
    elif raw_label == 'neg':
        label = 0
    else:
        print(raw_label)
        label = -1
    review = ','.join(fields[2:-2])[1:-1]
    return review, label, split_name
def load_data(path='data/imdb_master.csv'):
    """Read the IMDB master CSV and return ((x_train, y_train), (x_test, y_test)).

    Each review is tokenized, indexed, restricted to the top max_words
    words, and padded/truncated to max_len.  The vocabulary dicts built
    along the way are persisted via save_dicts() before returning.
    """
    with open(path) as f:
        lines = f.readlines()
    lines.pop(0)  # drop the CSV header row
    lines = [l for l in lines if 'unsup' not in l]  # skip unlabelled rows
    x_train, y_train = [], []
    x_test, y_test = [], []
    pbar = Progbar(len(lines))
    for line in lines:
        pbar.add(1)
        review, label, split = parse_row(line)
        vec = vectorize(review)
        if split == 'train':
            x_train.append(vec)
            y_train.append(label)
        else:
            x_test.append(vec)
            y_test.append(label)
    x_train = [pad(remove_uncommon_words(v)) for v in x_train]
    x_test = [pad(remove_uncommon_words(v)) for v in x_test]
    save_dicts()
    return (np.array(x_train), np.array(y_train)), (np.array(x_test), np.array(y_test))
def remove_puncts(word):
    """Lowercase each character of word and keep only those in [a-z0-9].

    The ranges are checked explicitly (rather than with str.isalnum /
    str.isdigit) so only ASCII letters and digits survive, preserving the
    original behavior.  The character-by-character lowercasing is also kept
    for exact parity.  Rewritten with ''.join instead of repeated `+=`
    string concatenation, which is quadratic in the worst case.
    """
    return ''.join(
        lc for lc in map(str.lower, word)
        if ('0' <= lc <= '9') or ('a' <= lc <= 'z')
    )
def vectorize(sentance):
    """Tokenize a sentence on spaces, strip punctuation from every token,
    and convert the surviving (non-empty) tokens to vocabulary indices via
    get_index().  (Parameter name kept as-is to preserve the interface.)
    """
    cleaned = (remove_puncts(tok) for tok in sentance.split(' '))
    return [get_index(tok) for tok in cleaned if tok]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment