Created
November 8, 2018 18:18
-
-
Save mbednarski/89911b9a6e31a16bb6f9ba1a3789841d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| from sklearn.metrics import classification_report | |
| from mlp import DumbModel, Dataset | |
def train_model(dataset_dir, model_file, vocab_size):
    """Train a ``DumbModel``, save it to disk, then print a test-set report.

    Args:
        dataset_dir: Directory whose ``train`` and ``test`` subdirectories are
            handed to ``Dataset``.
        model_file: Destination passed to ``DumbModel.serialize``.
        vocab_size: Forwarded unchanged to the ``DumbModel`` constructor.
    """
    for banner in (f'Training model from directory {dataset_dir}',
                   f'Vocabulary size: {vocab_size}'):
        print(banner)

    # Dataset takes (train_dir, test_dir) positionally, in that order.
    split_dirs = [os.path.join(dataset_dir, split) for split in ('train', 'test')]
    dataset = Dataset(*split_dirs)

    features, labels = dataset.get_train_set()
    classifier = DumbModel(vocab_size=vocab_size)
    classifier.train(features, labels)

    # The model is persisted before evaluation, so it is saved even if
    # evaluation fails afterwards.
    print(f'Storing model to {model_file}')
    classifier.serialize(model_file)

    test_features, test_labels = dataset.get_test_set()
    predictions = classifier.predict(test_features)
    print(classification_report(test_labels, predictions))
def ask_model(model_file, question):
    """Load a serialized ``DumbModel`` and print its prediction for one question.

    Args:
        model_file: Path handed to ``DumbModel.deserialize``.
        question: A single input. It is wrapped in a one-element list before
            calling ``predict_proba`` (presumably that method expects a batch
            of inputs — confirm against ``mlp.DumbModel``).

    Prints only the first entry of the ``predict_proba`` result, i.e. the
    output for ``question``.
    """
    print(f'Asking model {model_file} about "{question}"')
    loaded = DumbModel.deserialize(model_file)
    batch_output = loaded.predict_proba([question])
    print(batch_output[0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment