Created
December 2, 2019 19:14
-
-
Save thomasjpfan/c18241576c84a812bc67e14e8ebafdd8 to your computer and use it in GitHub Desktop.
Benchmarking script
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import csv
from pathlib import Path

import numpy as np
import scipy
import sklearn
import openml
from openml.exceptions import OpenMLRunsExistError
from sklearn.experimental import enable_hist_gradient_boosting  # noqa: F401
from sklearn.experimental import enable_halving_search_cv  # noqa: F401
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import HalvingRandomSearchCV
# Hyper-parameter search space for HistGradientBoostingClassifier.
# (original author's note: is this reasonable?)
param_distributions = {
    'max_depth': [*range(5, 10), 1000],       # 1000 acts as "effectively unbounded"
    'max_leaf_nodes': [*range(30, 40)],
    'min_samples_leaf': [2, *range(20, 30)],
    'learning_rate': [.01, .1, 1],
    'l2_regularization': [0, .01, .1],
}
# Base estimator shared (by reference) by every search strategy below.
clf = HistGradientBoostingClassifier()
n_candidates = 100  # is this too much?  number of sampled parameter settings
n_jobs = 3  # parallel workers for each search's CV fits
# Search strategies under comparison, selected at the CLI by key.
# NOTE: HalvingRandomSearchCV is experimental and must be enabled with
# ``from sklearn.experimental import enable_halving_search_cv`` before the
# ``sklearn.model_selection`` import succeeds.
models = {
    # Plain randomized search: every candidate is evaluated on the full data.
    'rs':
        RandomizedSearchCV(clf,
                           param_distributions=param_distributions,
                           n_iter=n_candidates,
                           random_state=0,
                           n_jobs=n_jobs,
                           verbose=1),
    # Successive halving with default resource allocation (n_samples).
    'sh_defaults':
        HalvingRandomSearchCV(clf,
                              param_distributions=param_distributions,
                              n_candidates=n_candidates,
                              random_state=0,
                              n_jobs=n_jobs,
                              verbose=1),
    # Halving sized so the last iteration uses as much of the resource as
    # possible.  ``min_resources='exhaust'`` is the released-API (>=0.24)
    # spelling of the old development-branch ``force_exhaust_resources=True``.
    'sh_force_exhaust':
        HalvingRandomSearchCV(clf,
                              param_distributions=param_distributions,
                              n_candidates=n_candidates,
                              random_state=0,
                              n_jobs=n_jobs,
                              min_resources='exhaust',
                              verbose=1),
    # Budget the estimator's ``max_iter`` (boosting iterations) instead of
    # the number of samples, capped at 100 iterations.
    'sh_max_iter':
        HalvingRandomSearchCV(clf,
                              param_distributions=param_distributions,
                              n_candidates=n_candidates,
                              random_state=0,
                              n_jobs=n_jobs,
                              min_resources='exhaust',
                              resource='max_iter',
                              max_resources=100,
                              verbose=1),
}
# CLI: pick exactly one of the search strategies defined above.
parser = argparse.ArgumentParser()
parser.add_argument('search', choices=models.keys())
args = parser.parse_args()

# One output directory per strategy; task_list.csv records finished tasks so
# an interrupted benchmark can be resumed without redoing work.
root_path = Path(args.search)
root_path.mkdir(exist_ok=True)
results_path = root_path / "task_list.csv"
col_names = ['task_id', 'run_id']

# Fix: start from an empty *set* (the original used ``{}``, an empty dict),
# so ``completed_tasks`` has the same type on both branches below.
completed_tasks = set()
if not results_path.exists():
    # First run: create the CSV with only the header row.
    with results_path.open('w') as f:
        writer = csv.DictWriter(f, fieldnames=col_names)
        writer.writeheader()
else:
    # Resuming: collect ids of tasks that were already benchmarked.
    with results_path.open('r') as f:
        reader = csv.DictReader(f, fieldnames=col_names)
        next(reader)  # skip header
        completed_tasks = set(int(row['task_id']) for row in reader)
# Run the chosen search on every task in the OpenML-CC18 suite, skipping the
# tasks already recorded in task_list.csv, and append each finished run.
benchmark_suite = openml.study.get_suite('OpenML-CC18')
for task_id in benchmark_suite.tasks:
    if task_id in completed_tasks:
        print('TASK_ID', task_id, 'was completed, skipping')
        continue

    print('TASK_ID running', task_id)
    task = openml.tasks.get_task(task_id)
    # argparse `choices` guarantees this key exists, so the lookup is safe
    # outside the try block.
    model = models[args.search]
    try:
        run = openml.runs.run_model_on_task(model, task)
    except OpenMLRunsExistError:
        # The server already has this run; move on to the next task.
        continue

    # Upload first, then record locally, so an interrupted upload is retried
    # on the next invocation.
    print('TASK_ID finished - uploading', task_id)
    run.publish()
    print('TASK_ID saving locally', task_id)
    with results_path.open('a') as f:
        csv.DictWriter(f, fieldnames=col_names).writerow(
            {'task_id': task_id, 'run_id': run.run_id})
    run.to_filesystem(str(root_path / "task_{}".format(task_id)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment