Created January 31, 2022 17:22
-
-
Save thomasjpfan/21859af29d1f2d0057dde141daf847dd to your computer and use it in GitHub Desktop.
simple_tree_bench.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Benchmark ``DecisionTreeClassifier.fit`` wall-clock time across dataset sizes.

For every sample count, ``n_repeats`` synthetic datasets are generated with
``make_classification`` and a ``DecisionTreeClassifier`` is fitted on each.
The repeat index seeds both the data and the estimator. Only the ``fit`` call
is timed.

The results are written as JSON mapping each sample count to the list of fit
durations in seconds. JSON object keys are always strings, so the sample
counts appear as e.g. ``"1000"`` in the output file.

Usage::

    python simple_tree_bench.py results.json
    python simple_tree_bench.py results.json --n-samples 1000 5000 --n-repeats 5
"""
import argparse
import json
import sys
import time
from collections import defaultdict

from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier
from tqdm import tqdm

# Defaults reproduce the original hard-coded benchmark configuration.
N_SAMPLES = (1_000, 5_000, 10_000, 50_000)
N_FEATURES = 100
N_REPEATS = 30


def run_benchmark(n_samples=N_SAMPLES, n_features=N_FEATURES, n_repeats=N_REPEATS):
    """Time ``DecisionTreeClassifier.fit`` for each requested sample count.

    Parameters
    ----------
    n_samples : iterable of int
        Dataset sizes (``n_samples`` for ``make_classification``) to benchmark.
    n_features : int
        Number of features passed to ``make_classification``.
    n_repeats : int
        Number of timed fits per dataset size. The repeat index is used as the
        ``random_state`` for both data generation and the estimator.

    Returns
    -------
    dict
        Maps each sample count to a list of ``n_repeats`` fit durations, in
        seconds, measured with ``time.perf_counter``.
    """
    results = defaultdict(list)
    for n_sample in tqdm(n_samples, desc=" n_sample", position=0):
        for n_repeat in tqdm(
            range(n_repeats), desc=" n_repeat", position=1, leave=False
        ):
            # Data generation and estimator construction stay outside the
            # timed region so that only fit() is measured.
            X, y = make_classification(
                random_state=n_repeat, n_features=n_features, n_samples=n_sample
            )
            dc = DecisionTreeClassifier(random_state=n_repeat)
            start = time.perf_counter()
            dc.fit(X, y)
            results[n_sample].append(time.perf_counter() - start)
    return dict(results)


def _parse_args(argv=None):
    """Parse CLI arguments. Every option defaults to the original constants."""
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    # FileType opens the output up front, so an unwritable path fails fast,
    # before a potentially long benchmark. '-' selects stdout.
    parser.add_argument(
        "results",
        type=argparse.FileType("w"),
        help="Output JSON file ('-' for stdout).",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        nargs="+",
        default=list(N_SAMPLES),
        help="Dataset sizes to benchmark (default: %(default)s).",
    )
    parser.add_argument(
        "--n-features",
        type=int,
        default=N_FEATURES,
        help="Number of features per dataset (default: %(default)s).",
    )
    parser.add_argument(
        "--n-repeats",
        type=int,
        default=N_REPEATS,
        help="Timed fits per dataset size (default: %(default)s).",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Run the benchmark and write the timings as JSON to the chosen output."""
    args = _parse_args(argv)
    results = run_benchmark(args.n_samples, args.n_features, args.n_repeats)
    try:
        json.dump(results, args.results)
    finally:
        # Close (and thereby flush) the output file explicitly. Leave stdout
        # open when '-' was given.
        if args.results is not sys.stdout:
            args.results.close()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.