Skip to content

Instantly share code, notes, and snippets.

@mbednarski
Created November 8, 2018 18:18
Show Gist options
  • Save mbednarski/89911b9a6e31a16bb6f9ba1a3789841d to your computer and use it in GitHub Desktop.
Save mbednarski/89911b9a6e31a16bb6f9ba1a3789841d to your computer and use it in GitHub Desktop.
import os
from sklearn.metrics import classification_report
from mlp import DumbModel, Dataset
def train_model(dataset_dir, model_file, vocab_size) -> None:
    """Train a ``DumbModel``, store it, and print a test-set report.

    The dataset is read from the ``train`` and ``test`` subdirectories of
    *dataset_dir*. The model is serialized right after training, before the
    evaluation step runs.

    Args:
        dataset_dir: Directory containing the ``train`` and ``test``
            subdirectories handed to ``Dataset``.
        model_file: Destination passed to ``DumbModel.serialize``.
        vocab_size: Forwarded to ``DumbModel`` as its ``vocab_size`` argument.
    """
    print(f'Training model from directory {dataset_dir}')
    print(f'Vocabulary size: {vocab_size}')

    train_dir = os.path.join(dataset_dir, 'train')
    test_dir = os.path.join(dataset_dir, 'test')
    dset = Dataset(train_dir, test_dir)

    X, y = dset.get_train_set()
    model = DumbModel(vocab_size=vocab_size)
    model.train(X, y)

    print(f'Storing model to {model_file}')
    model.serialize(model_file)

    # Evaluate on the test split and print the classification report.
    X_test, y_test = dset.get_test_set()
    y_pred = model.predict(X_test)
    print(classification_report(y_test, y_pred))
def ask_model(model_file, question) -> None:
    """Load a stored ``DumbModel`` and print its prediction for one question.

    Args:
        model_file: Source passed to ``DumbModel.deserialize``.
        question: The single input to score; it is wrapped in a one-element
            list before being passed to ``predict_proba``.
    """
    print(f'Asking model {model_file} about "{question}"')
    model = DumbModel.deserialize(model_file)

    # predict_proba receives a one-item batch; only the first result is shown.
    y_pred = model.predict_proba([question])
    print(y_pred[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment