import numpy as np
import json
from keras.utils import Progbar

# Dictionary management. Index max_words (20000) is reserved for the <UNK>
# token, which doubles as the padding value; real words get the remaining
# indices.
idx2word = {20000: '<UNK>'}
word2idx = {'<UNK>': 20000}
frequencies = {}
max_words = 20000  # vocabulary size kept after frequency filtering
max_len = 80       # reviews are truncated/padded to this many tokens
top_word_idxs = None

def save_dicts():
    # Persist the vocabulary state so it can be reloaded in later runs.
    dicts = [idx2word, word2idx, frequencies]
    with open('dicts.json', 'w') as f:
        json.dump(dicts, f)

def load_dicts():
    # Restore the vocabulary state saved by save_dicts().
    global idx2word, word2idx, frequencies
    global top_word_idxs
    with open('dicts.json', 'r') as f:
        dicts = json.load(f)
    idx2word, word2idx, frequencies = dicts
    # JSON object keys are always strings, so restore idx2word's int keys.
    idx2word = {int(k): v for k, v in idx2word.items()}
    # Rebuild the set of indices of the max_words most frequent words,
    # looking them up directly so the frequency counts are not disturbed.
    all_words = sorted(frequencies, key=lambda w: -frequencies[w])
    top_words = all_words[:max_words]
    top_word_idxs = set(word2idx[w] for w in top_words)
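
# Inference-time sketch (an addition, not part of the original gist): after
# load_dicts(), a fresh review can be encoded with the same steps load_data()
# applies below, passing add_new=False so the vocabulary no longer grows.
def encode_review(text):
    words = [remove_puncts(w) for w in text.split(' ')]
    words = [w for w in words if w]
    idxs = [get_index(w, add_new=False) for w in words]
    return pad(remove_uncommon_words(idxs))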

def pad(word_idxs):
    # Truncate long reviews, and right-pad short ones with the reserved
    # <UNK>/pad index, so every review is exactly max_len long.
    if len(word_idxs) >= max_len:
        return word_idxs[:max_len]
    return word_idxs + [max_words] * (max_len - len(word_idxs))

def remove_uncommon_words(input_word_idxs):
    # Map every index outside the max_words most frequent words to the <UNK>
    # index. The kept-index set is built lazily on first use; a set makes the
    # per-token membership test O(1) instead of a scan over 20000 entries.
    global top_word_idxs
    if top_word_idxs is None:
        all_words = sorted(frequencies, key=lambda w: -frequencies[w])
        top_words = all_words[:max_words]
        top_word_idxs = set(word2idx[w] for w in top_words)
    return [i if i in top_word_idxs else max_words for i in input_word_idxs]

def get_index(word, add_new=True):
    # Return the index for word, optionally assigning a fresh one. Known
    # words also get their frequency count bumped.
    if word in word2idx:
        frequencies[word] += 1
        return word2idx[word]
    if not add_new:
        return max_words  # unseen word maps to <UNK>
    idx = len(word2idx) - 1  # next free index: 0, 1, 2, ...
    if idx >= max_words:
        idx += 1  # skip max_words, which is reserved for <UNK>/padding
    idx2word[idx] = word
    word2idx[word] = idx
    frequencies[word] = 1
    return idx
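
# Sanity-check sketch (illustrative, not in the original gist): fresh words
# receive successive indices, and repeated words reuse their index while the
# frequency counter climbs.
# >>> get_index('great'), get_index('awful'), get_index('great')
# (0, 1, 0)
# >>> frequencies['great']
# 2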

def parse_row(line):
    # imdb_master.csv columns: row id, type ('train'/'test'), review, label,
    # file. The review text itself contains commas, so the middle fields are
    # rejoined after the naive split.
    s = line.split(',')
    t = s[1]
    label = s[-2]
    if label == 'pos':
        label = 1
    elif label == 'neg':
        label = 0
    else:
        print(label)  # anything else should have been filtered out already
        label = -1
    review = ','.join(s[2:-2])
    review = review[1:-1]  # strip the surrounding quotes
    return review, label, t

def load_data(path='data/imdb_master.csv'):
    with open(path) as f:
        lines = f.readlines()
    lines.pop(0)  # drop the CSV header
    lines = [l for l in lines if 'unsup' not in l]  # drop unlabelled reviews
    #lines = lines[:10] + lines[-10:]  # uncomment to smoke-test on a tiny subset
    x_train = []
    y_train = []
    x_test = []
    y_test = []
    pbar = Progbar(len(lines))
    for line in lines:
        pbar.add(1)
        review, label, split = parse_row(line)
        x = vectorize(review)
        if split == 'train':
            x_train.append(x)
            y_train.append(label)
        else:
            x_test.append(x)
            y_test.append(label)
    # Frequency filtering and padding can only run after the full pass, once
    # every word's count is known.
    x_train = [remove_uncommon_words(x) for x in x_train]
    x_test = [remove_uncommon_words(x) for x in x_test]
    x_train = [pad(x) for x in x_train]
    x_test = [pad(x) for x in x_test]
    x_train = np.array(x_train)
    x_test = np.array(x_test)
    y_train = np.array(y_train)
    y_test = np.array(y_test)
    save_dicts()
    return (x_train, y_train), (x_test, y_test)

def remove_puncts(word):
    # Lowercase the word and keep only ASCII letters and digits.
    y = ''
    for c in word:
        c = c.lower()
        if '0' <= c <= '9' or 'a' <= c <= 'z':
            y += c
    return y

def vectorize(sentence):
    # Turn raw text into a list of word indices, growing the vocabulary as
    # new words appear.
    words = sentence.split(' ')
    words = [remove_puncts(w) for w in words]
    words = [w for w in words if w]  # drop tokens that were pure punctuation
    return [get_index(word) for word in words]
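
# End-to-end sketch (an addition, not part of the original gist): feed the
# arrays into a small Keras classifier. The layer sizes are illustrative
# assumptions. Note that kept words retain their original indices, which can
# exceed max_words, so the embedding must cover the full index range.
if __name__ == '__main__':
    from keras.models import Sequential
    from keras.layers import Embedding, LSTM, Dense

    (x_train, y_train), (x_test, y_test) = load_data()
    num_indices = max(word2idx.values()) + 1  # covers every id incl. <UNK>/pad

    model = Sequential([
        Embedding(input_dim=num_indices, output_dim=32, input_length=max_len),
        LSTM(32),
        Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train, validation_data=(x_test, y_test),
              epochs=2, batch_size=64)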