|
import time |
|
import argparse |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from sklearn.datasets import fetch_20newsgroups_vectorized |
|
from sklearn.tree import DecisionTreeClassifier |
|
from joblib import dump, load |
|
from joblib import Parallel, delayed |
|
|
|
# Number of pruning runs launched (and timed) for every ccp_alpha value.
RUNS_PER_ALPHA = 5
|
|
|
|
|
def timeit(f, *args, **kwargs):
    """Call ``f(*args, **kwargs)`` and measure how long the call takes.

    Parameters
    ----------
    f : callable
        The function to time.
    *args, **kwargs
        Forwarded unchanged to ``f``.

    Returns
    -------
    output : object
        Whatever ``f`` returned.
    duration : float
        Elapsed time of the call, in seconds.

    Any exception raised by ``f`` propagates to the caller.
    """
    # perf_counter() is monotonic and uses the highest-resolution clock
    # available. time.time() follows the system clock, so an adjustment
    # during the call could produce a skewed or even negative duration.
    start = time.perf_counter()
    output = f(*args, **kwargs)
    end = time.perf_counter()
    return output, end - start
|
|
|
|
|
def prune_clf(ccp_alpha, model_fn):
    """Time one pruning pass on a freshly loaded copy of the saved model.

    Parameters
    ----------
    ccp_alpha : float
        Value assigned to the estimator's ``ccp_alpha`` parameter before
        pruning.
    model_fn : str
        Path of the joblib dump the estimator is loaded from.

    Returns
    -------
    tuple
        ``(prune_seconds, max_depth)``: the duration of the ``prune_tree()``
        call as reported by ``timeit``, and ``tree_.max_depth`` of the
        estimator after that call.
    """
    # Load from disk on every call so each run works on its own copy of
    # the stored model.
    model = load(model_fn)
    model.set_params(ccp_alpha=ccp_alpha)
    # NOTE(review): prune_tree() is not defined in this file; presumably it
    # comes from the scikit-learn build under benchmark -- confirm.
    elapsed = timeit(model.prune_tree)[1]
    return elapsed, model.tree_.max_depth
|
|
|
|
|
def measure(version):
    """Benchmark pruning time relative to fit time and save the results.

    Fits a ``DecisionTreeClassifier`` on the vectorized 20 newsgroups data,
    dumps it to ``{version}_overfit.pt``, and then, for each ``ccp_alpha``
    value, launches ``RUNS_PER_ALPHA`` pruning runs through joblib. One row
    per alpha (mean tree depth, mean and std of prune time divided by fit
    time) is written to ``{version}_results.txt``.

    Parameters
    ----------
    version : str
        Prefix used for the model dump and the results file.
    """
    X, y = fetch_20newsgroups_vectorized(return_X_y=True)

    clf = DecisionTreeClassifier(random_state=42)

    print("Overfitting model")
    _, fit_dur = timeit(clf.fit, X, y)
    print(f"Overfitted in {fit_dur} seconds")

    # Persist the fitted model so every pruning run can reload it.
    overfit_model_fn = f"{version}_overfit.pt"
    dump(clf, overfit_model_fn)

    # Hand-picked alphas. A previous np.linspace(0, 1e-3, 10) assignment was
    # overwritten immediately by this list and has been dropped as dead code.
    alphas = [0, 2e-4, 3e-4, 4e-4, 5e-4, 6e-4, 9e-4, 4e-3, 4e-2]

    # 3 columns = tree_depth, mean, std
    results = np.empty((len(alphas), 3), dtype=np.float32)

    for i, ccp_alpha in enumerate(alphas):
        print(f"Pruning for ccp_alpha: {ccp_alpha}")

        # Each job returns (prune_duration, max_depth).
        ccp_alpha_results = Parallel(n_jobs=-1)(
            delayed(prune_clf)(ccp_alpha, overfit_model_fn)
            for _ in range(RUNS_PER_ALPHA))

        ccp_alpha_results = np.array(ccp_alpha_results)
        # Express pruning time as a fraction of the original fit time.
        prune_frac = ccp_alpha_results[:, 0] / fit_dur
        results[i, 0] = np.mean(ccp_alpha_results[:, 1])
        results[i, 1] = np.mean(prune_frac)
        results[i, 2] = np.std(prune_frac)
        print(results[i, :])

    np.savetxt(f"{version}_results.txt", results)
|
|
|
|
|
def plot(version):
    """Plot the saved benchmark results and write them to a PNG file.

    Reads ``{version}_results.txt`` (columns: tree depth, mean and std of
    the prune-time / fit-time fraction), draws the mean with error bars
    against the pruned tree depth, and saves the figure as
    ``{version}_plot.png``.

    Parameters
    ----------
    version : str
        Prefix of the results file to read and of the image to write.
    """
    # ndmin=2 keeps a single-row results file two-dimensional so the column
    # slicing below still works.
    results = np.loadtxt(f"{version}_results.txt", ndmin=2)
    # The builtin int replaces the np.int alias, which was deprecated in
    # NumPy 1.20 and removed in 1.24.
    tree_depth = results[:, 0].astype(int)
    prune_mean = results[:, 1]
    prune_std = results[:, 2]

    fig, ax = plt.subplots(figsize=(12, 8))
    idx = np.arange(0, len(results))
    ax.errorbar(idx, prune_mean, yerr=prune_std)
    # Pin one tick per data point before labelling. Setting labels alone
    # depends on wherever the automatic locator happens to place ticks.
    ax.set_xticks(idx)
    ax.set_xticklabels([str(depth) for depth in tree_depth])
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.set_title(
        f"Fraction of time added to fit when pruning - {version}", size=20)
    ax.set_xlabel("Pruned tree depth", size=20)
    ax.set_ylabel("Fraction of time added to fit", size=20)
    fig.savefig(f"{version}_plot.png")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: <script> VERSION {measure,plot}
    parser = argparse.ArgumentParser(description="Bench prune tree")
    parser.add_argument("version", type=str)
    # Restricting the choices makes argparse reject an unknown function name
    # with a usage error instead of the script silently doing nothing.
    parser.add_argument("function", type=str, choices=["measure", "plot"])

    args = parser.parse_args()

    if args.function == "measure":
        measure(args.version)
    elif args.function == "plot":
        plot(args.version)