Skip to content

Instantly share code, notes, and snippets.

@thomasjpfan
Created January 31, 2022 17:22
Show Gist options
  • Select an option

  • Save thomasjpfan/21859af29d1f2d0057dde141daf847dd to your computer and use it in GitHub Desktop.

Select an option

Save thomasjpfan/21859af29d1f2d0057dde141daf847dd to your computer and use it in GitHub Desktop.
simple_tree_bench.py
import json
from collections import defaultdict
import time
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from argparse import ArgumentParser
import argparse
from tqdm import tqdm
# Benchmark grid: dataset sizes to time, the fixed feature count, and how many
# differently seeded repetitions to run for each size.
n_samples = [1_000, 5_000, 10_000, 50_000]
n_features = 100
n_repeats = 30


def build_parser():
    """Return the command-line parser for the benchmark.

    The parser takes a single positional argument, ``results``, which
    argparse opens for writing while parsing the arguments.
    """
    parser = ArgumentParser()
    parser.add_argument("results", type=argparse.FileType("w"))
    return parser


def run_benchmark(sample_sizes, feature_count, repeats):
    """Time ``DecisionTreeClassifier.fit`` for every requested dataset size.

    Args:
        sample_sizes: Iterable of sample counts to benchmark.
        feature_count: Number of features of each generated dataset.
        repeats: Number of repetitions per sample count. The repetition
            index is used as the random seed for both the generated
            dataset and the classifier.

    Returns:
        A ``defaultdict(list)`` mapping each sample count to the list of
        elapsed ``fit`` times in seconds, one entry per repetition.
    """
    timings = defaultdict(list)
    for n_sample in tqdm(sample_sizes, desc=" n_sample", position=0):
        for n_repeat in tqdm(
            range(repeats), desc=" n_repeat", position=1, leave=False
        ):
            # Build the dataset before starting the clock so that only the
            # call to fit() is measured.
            X, y = make_classification(
                random_state=n_repeat,
                n_features=feature_count,
                n_samples=n_sample,
            )
            dc = DecisionTreeClassifier(random_state=n_repeat)
            start = time.perf_counter()
            dc.fit(X, y)
            timings[n_sample].append(time.perf_counter() - start)
    return timings


def main(argv=None):
    """Run the benchmark and dump the timings as JSON to the results file.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    # The output file is already open once parsing succeeds; the ``with``
    # block guarantees it is flushed and closed even if the benchmark raises.
    with args.results as out:
        timings = run_benchmark(n_samples, n_features, n_repeats)
        # json.dump serialises the integer sample counts as string keys.
        json.dump(timings, out)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment