Skip to content

Instantly share code, notes, and snippets.

@rubenfiszel
Created December 12, 2016 23:13
Show Gist options
  • Save rubenfiszel/3af3b8dd32269407581693c835b7e61e to your computer and use it in GitHub Desktop.
Save rubenfiszel/3af3b8dd32269407581693c835b7e61e to your computer and use it in GitHub Desktop.
import numpy as np
np.random.seed(1337) # for reproducibility
from keras.preprocessing import sequence
from keras.models import Sequential
import as_keras_ds
import argparse
from keras.models import load_model
import sys
import os
import defaults
from sentiword import seq_to_sentiword
# Command-line interface. The positional name picks the saved model to load;
# the flags pick the tokenizer variant, the padding length, and whether the
# model is fed an additional sentiword input.
parser = argparse.ArgumentParser(description='Process test data from model')
parser.add_argument('fname', type=str, nargs=1,
                    help='name of the model')
parser.add_argument('-full', action='store_true',
                    help='Use full datasets')
parser.add_argument('-maxlen', type=int, nargs='?',
                    help='Max length of the sequence')
parser.add_argument('-sw', action='store_true',
                    help='Use sentiword in embedding')
args = parser.parse_args()
sw = args.sw
# -maxlen is None when omitted, and also when given without a value
# (nargs='?' with no const). None and 0 both fall back to the default length.
maxlen = args.maxlen or defaults.MAX_LEN
# Resolve resources relative to this script's directory, not the current
# working directory. The trailing '/' is kept because other paths in this
# script are built by plain string concatenation onto cdir.
cdir = os.path.dirname(os.path.abspath(__file__)) + '/'
# -full switches to the 'tokenizer-full.pkl' pickle instead of 'tokenizer.pkl'.
fname = 'tokenizer-full.pkl' if args.full else 'tokenizer.pkl'
tokenizer = as_keras_ds.load_tokenizer(os.path.join(cdir, fname))
# NOTE: an earlier variant built the test set from raw stdin lines via
# as_keras_ds.build_seq(tokenizer, sys.stdin.read().splitlines()).
# Only the test split returned by load_data is used below.
# NOTE(review): the meaning of the two boolean flags passed to load_data is
# not visible here; confirm against as_keras_ds.
((X_train, y_train), (X_test, Y_test)), _, _ = as_keras_ds.load_data("fasttest.vec", True, False)
print("Pad sequences (samples x time)")
X_test = sequence.pad_sequences(X_test, maxlen=maxlen)
if sw:
    print("Seq to sentiword")
    # Invert the tokenizer's word -> index map into index -> word; presumably
    # seq_to_sentiword uses it to turn index sequences back into words.
    inv_word_index = {v: k for k, v in tokenizer.word_index.items()}
    X_test_sentiword = seq_to_sentiword(X_test, inv_word_index)
# Load the trained model stored as saved/<name>.h5 next to this script.
fname = os.path.join(cdir, 'saved', args.fname[0] + '.h5')
print(fname)
model = load_model(fname)
# With -sw the model takes a second input (the sentiword features). Build the
# input once so evaluation and prediction feed the model the same data.
if sw:
    model_inputs = [X_test, X_test_sentiword]
else:
    model_inputs = X_test
# evaluate() returns the test loss (plus any compiled metrics); report it
# rather than discarding it.
score = model.evaluate(model_inputs, Y_test)
print('Test evaluation: %s' % (score,))
predictions = model.predict(model_inputs)
# Write one rounded prediction per line. The with-statement closes the file
# even if a write fails.
# NOTE(review): this path is relative to the current working directory, unlike
# the cdir-based input paths, and the 'predictions' directory must already exist.
with open('predictions/keras_predictions', 'w') as f:
    for p in predictions:
        # p[0] is the model's first output for this sample, presumably a
        # sigmoid probability, so rounding yields a 0/1 label. Tie-breaking at
        # exactly 0.5 follows the running Python version's round().
        f.write(str(int(round(p[0]))) + "\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment