Skip to content

Instantly share code, notes, and snippets.

@thomasjpfan
Last active May 4, 2022 03:14
Show Gist options
  • Select an option

  • Save thomasjpfan/e3a84bac19469651f7ecf6c0bb109bfb to your computer and use it in GitHub Desktop.

Select an option

Save thomasjpfan/e3a84bac19469651f7ecf6c0bb109bfb to your computer and use it in GitHub Desktop.
tree splitter memory view benchmark
from functools import partial
import argparse
from time import perf_counter
from statistics import mean, stdev
from itertools import product
import csv
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.datasets import make_classification, make_regression, make_low_rank_matrix
import numpy as np
from scipy.sparse import csc_matrix
# Command-line interface: one positional argument naming the CSV file that
# receives one row per timed fit. argparse.FileType opens it for writing
# ("-" selects stdout).
parser = argparse.ArgumentParser(
    description="Time scikit-learn decision-tree fits on dense and sparse inputs.",
)
parser.add_argument(
    "results",
    type=argparse.FileType("w"),
    help="output CSV file; one row is written per timed fit",
)
args = parser.parse_args()
# Parameter grid shared by every estimator: (n_samples, container, splitter).
# itertools.product returns a one-shot iterator. Materializing it once keeps
# benchmark_config re-iterable instead of silently empty after its first pass.
_PARAM_GRID = tuple(
    product(
        [15_000],
        ["numpy", "sparse"],
        ["best", "random"],
    )
)

# Each entry: (estimator class, criterion, data factory, parameter grid).
benchmark_config = [
    (
        DecisionTreeRegressor,
        "squared_error",
        partial(make_regression, n_targets=2),
        _PARAM_GRID,
    ),
    (
        DecisionTreeClassifier,
        "gini",
        partial(make_classification, n_informative=10, n_classes=5),
        _PARAM_GRID,
    ),
]
# Number of timed fits per (estimator, n_samples, container, splitter) case.
N_REPEATS = 15

# One CSV row per timed fit. The column names match the keys of the row dicts
# written in the benchmark loop.
writer = csv.DictWriter(
    args.results,
    fieldnames=[
        "criterion",
        "n_samples",
        "container",
        "splitter",
        "n_repeat",
        "duration",
    ],
)
writer.writeheader()
# Time `fit` for every configuration: write one CSV row per repeat, then print
# a mean +/- stdev summary (in seconds) per configuration.
for Klass, criterion, make_data, items in benchmark_config:
    for config in items:
        n_samples, container, splitter = config
        klass_results = []
        for n_repeat in range(N_REPEATS):
            # Data and estimator are both seeded with the repeat index, so
            # every repeat is reproducible across separate runs of the script.
            X, y = make_data(n_samples=n_samples, random_state=n_repeat, n_features=100)
            tree = Klass(random_state=n_repeat, criterion=criterion, splitter=splitter)
            if container == "sparse":
                # Only the sparse case is cast to float32 here. Dense X keeps
                # whatever dtype make_data returned.
                X = csc_matrix(X, dtype=np.float32)
            # Only `fit` is inside the timed region. Data generation and the
            # sparse conversion happen before the clock starts.
            start = perf_counter()
            tree.fit(X, y)
            duration = perf_counter() - start
            klass_results.append(duration)
            writer.writerow(
                {
                    "criterion": criterion,
                    "n_samples": n_samples,
                    "container": container,
                    "splitter": splitter,
                    "n_repeat": n_repeat,
                    "duration": duration,
                }
            )
        results_mean = mean(klass_results)
        # statistics.stdev needs at least two samples. Report 0.0 instead of
        # raising StatisticsError when N_REPEATS is 1.
        results_stdev = stdev(klass_results) if len(klass_results) > 1 else 0.0
        print(
            f"criterion={criterion} container={container} n_samples={n_samples} "
            f"splitter={splitter} with"
            f" {results_mean:.3f} +/- {results_stdev:.3f}"
        )
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import seaborn as sns
# Compare fit-time distributions between the "pr" and "main" branches. Draws
# one violin plot per (container, splitter, criterion) group on a 2x4 grid.
# NOTE(review): the two CSVs are presumably outputs of the benchmark script,
# one per branch; confirm that they share its column names.
plt.rc('font', size=12)
pr = pd.read_csv("data/pr_mv_splitter.csv")
main = pd.read_csv("data/main_mv_splitter.csv")
df = pd.concat([pr.assign(branch="pr"), main.assign(branch="main")])
grouped = list(df.groupby(["container", "splitter", "criterion"]))
fig, axis = plt.subplots(2, 4, figsize=(14, 6), constrained_layout=True)
# zip stops at the shorter sequence: any groups beyond the 8 axes are not drawn,
# and any axes beyond the number of groups stay empty.
for ((container, splitter, criterion), subset), ax in zip(grouped, axis.reshape(-1)):
    sns.violinplot(data=subset, y="duration", x="branch", ax=ax)
    ax.set_title(f"{container} | {splitter} | {criterion}")
    ax.set_xlabel("")
# Keep the y-axis label only on the first column of subplots.
for ax in axis[:, 1:].ravel():
    ax.set_ylabel("")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment