Skip to content

Instantly share code, notes, and snippets.

@thomasjpfan
Last active February 5, 2019 23:04
Show Gist options
  • Select an option

  • Save thomasjpfan/ee04160eef4eae7647cc20f564c536a1 to your computer and use it in GitHub Desktop.

Select an option

Save thomasjpfan/ee04160eef4eae7647cc20f564c536a1 to your computer and use it in GitHub Desktop.
Cost Complexity Pruning benchmarks

How to run benchmarks

Run pure python version

  1. Set up a virtual env.
  2. git clone -b ccp_prune_tree https://github.com/thomasjpfan/scikit-learn scikit-learn-ccp-python
  3. cd scikit-learn-ccp-python
  4. git checkout 25910e085cf7bb0a98ee33c050fa9233e247e523
  5. Install scikit-learn
  6. Go to directory with bench_prune_tree.py
  7. python bench_prune_tree.py python measure
  8. python bench_prune_tree.py python plot

Run cython version

  1. Set up a virtual env.
  2. git clone -b ccp_prune_tree https://github.com/thomasjpfan/scikit-learn scikit-learn-ccp-cython
  3. cd scikit-learn-ccp-cython
  4. git checkout e59b662213ed1a39ac8f96f2adf8aa3197dde07d
  5. Install scikit-learn
  6. Go to directory with bench_prune_tree.py
  7. python bench_prune_tree.py cython measure
  8. python bench_prune_tree.py cython plot
import time
import argparse
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.tree import DecisionTreeClassifier
from joblib import dump, load
from joblib import Parallel, delayed
# Number of repeated pruning runs per ccp_alpha value; the mean/std reported
# by measure() are taken over these runs.
RUNS_PER_ALPHA = 5
def timeit(f, *args, **kwargs):
    """Call ``f(*args, **kwargs)`` and return ``(result, elapsed_seconds)``.

    Uses wall-clock time, so the measurement includes any I/O or
    parallel work done inside ``f``.
    """
    began = time.time()
    result = f(*args, **kwargs)
    elapsed = time.time() - began
    return result, elapsed
def prune_clf(ccp_alpha, model_fn):
    """Load the pickled classifier from ``model_fn``, prune at ``ccp_alpha``.

    Returns ``(prune_seconds, pruned_tree_max_depth)``. Loading from disk in
    each call lets this run under joblib.Parallel without sharing the
    fitted estimator between workers.
    """
    estimator = load(model_fn)
    estimator.set_params(ccp_alpha=ccp_alpha)
    _, elapsed = timeit(estimator.prune_tree)
    return elapsed, estimator.tree_.max_depth
def measure(version):
    """Benchmark cost-complexity pruning and save results for ``version``.

    Fits an intentionally overfit decision tree on the vectorized
    20newsgroups data, pickles it, then for each ccp_alpha prunes the
    pickled model RUNS_PER_ALPHA times in parallel. Writes one row per
    alpha to ``{version}_results.txt`` with columns:
    [mean pruned tree depth, mean prune/fit time fraction, std of fraction].
    """
    X, y = fetch_20newsgroups_vectorized(return_X_y=True)
    clf = DecisionTreeClassifier(random_state=42)
    print("Overfitting model")
    _, fit_dur = timeit(clf.fit, X, y)
    print(f"Overfitted in {fit_dur} seconds")
    overfit_model_fn = f"{version}_overfit.pt"
    dump(clf, overfit_model_fn)
    # Hand-picked alphas spanning no pruning (0) up to aggressive pruning.
    # (An earlier np.linspace(0, 1e-3, 10) grid was dead code, immediately
    # overwritten by this list, and has been removed.)
    alphas = [0, 2e-4, 3e-4, 4e-4, 5e-4, 6e-4, 9e-4, 4e-3, 4e-2]
    # 3 columns = mean tree depth, mean prune/fit fraction, std of fraction
    results = np.empty((len(alphas), 3), dtype=np.float32)
    for i, ccp_alpha in enumerate(alphas):
        print(f"Pruning for ccp_alpha: {ccp_alpha}")
        # Each worker reloads the pickled overfit model so runs are independent.
        ccp_alpha_results = Parallel(n_jobs=-1)(
            delayed(prune_clf)(ccp_alpha, overfit_model_fn)
            for _ in range(RUNS_PER_ALPHA))
        ccp_alpha_results = np.array(ccp_alpha_results)
        # Column 0 of each run is prune time; express it relative to fit time.
        prune_frac = ccp_alpha_results[:, 0] / fit_dur
        results[i, 0] = np.mean(ccp_alpha_results[:, 1])
        results[i, 1] = np.mean(prune_frac)
        results[i, 2] = np.std(prune_frac)
        print(results[i, :])
    np.savetxt(f"{version}_results.txt", results)
def plot(version):
    """Plot prune-time fraction vs. pruned tree depth for ``version``.

    Reads ``{version}_results.txt`` written by measure() (columns: mean
    depth, mean prune/fit fraction, std of fraction) and saves the figure
    to ``{version}_plot.png``.
    """
    results = np.loadtxt(f"{version}_results.txt")
    # np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # int is the documented replacement and behaves identically here.
    tree_depth = results[:, 0].astype(int)
    prune_mean = results[:, 1]
    prune_std = results[:, 2]
    fig, ax = plt.subplots(figsize=(12, 8))
    idx = np.arange(0, len(results))
    ax.errorbar(idx, prune_mean, yerr=prune_std)
    # NOTE(review): set_xticklabels without a matching set_xticks relies on
    # the default tick locations (the leading 0 label appears to target the
    # left-edge tick); recent matplotlib warns/errors on this — confirm and
    # pin tick positions with ax.set_xticks if the plot misaligns.
    ax.set_xticklabels([0] + list(tree_depth))
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.set_title(
        f"Fraction of time added to fit when pruning - {version}", size=20)
    ax.set_xlabel("Pruned tree depth", size=20)
    ax.set_ylabel("Fraction of time added to fit", size=20)
    fig.savefig(f"{version}_plot.png")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Bench prune tree")
    # version tags the output files (e.g. "python" or "cython").
    parser.add_argument("version", type=str)
    # Restrict to the known subcommands so a typo fails loudly with a usage
    # error instead of silently doing nothing (the original if/elif chain
    # fell through on unrecognized values).
    parser.add_argument("function", type=str, choices=["measure", "plot"])
    args = parser.parse_args()
    if args.function == "measure":
        measure(args.version)
    elif args.function == "plot":
        plot(args.version)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment