@lextoumbourou
Last active August 27, 2018 00:27
Making predictions with a SequentialRNN model (Fast.ai)
# Note: ensure you have the latest version of Torchtext by running: pip install torchtext --upgrade
import numpy as np
import pandas as pd
from IPython.display import FileLink

rnn_model = text_data.get_model(
    opt_fn, 1500, bptt, emb_sz=em_sz, n_hid=nh, n_layers=nl,
    dropout=0.1, dropouti=0.65, wdrop=0.5, dropoute=0.1, dropouth=0.3)
# ...
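# Turn off sorting and shuffling on the test iterator so the predictions
# come back in the same order as the rows of test_df.
rnn_model.data.test_dl.src.sort = False
rnn_model.data.test_dl.src.sort_within_batch = False
rnn_model.data.test_dl.src.shuffle = False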
probs = rnn_model.predict(is_test=True)
preds = np.argmax(probs, axis=1)
pd.DataFrame({
    'id': test_df['index'],
    'sentiment': [LABEL_FIELD.vocab.itos[p] for p in preds],
}).to_csv('./sub1.csv', index=False)
FileLink('./sub1.csv')
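
The last few lines of the snippet turn a matrix of per-class scores into label strings and write them out. Below is a minimal, self-contained sketch of that step on toy data. The names `itos` and `test_ids` and all the values are hypothetical stand-ins for `LABEL_FIELD.vocab.itos` and `test_df['index']`; nothing here calls Fast.ai or Torchtext.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the model output and the label vocabulary.
# Each row of `probs` holds one score per class for one test example.
probs = np.array([[0.1, 0.9],
                  [0.8, 0.2],
                  [0.3, 0.7]])
itos = ['negative', 'positive']   # assumed index-to-string mapping
test_ids = [10, 11, 12]           # assumed ids, in the same order as `probs`

# Pick the highest-scoring class index for each row.
preds = np.argmax(probs, axis=1)

# Map class indices back to label strings.
labels = [itos[p] for p in preds]

# Build the submission table; the ids line up with the labels only
# because the rows were never reordered.
submission = pd.DataFrame({'id': test_ids, 'sentiment': labels})
print(submission)
```

The pairing of ids and labels is purely positional, which is why the test iterator must not sort or shuffle its batches before `predict` is called.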