Skip to content

Instantly share code, notes, and snippets.

@mehdidc
Created May 22, 2016 23:13
Show Gist options
  • Save mehdidc/a8fe1758629f6f9125acb9883d895792 to your computer and use it in GitHub Desktop.
Save mehdidc/a8fe1758629f6f9125acb9883d895792 to your computer and use it in GitHub Desktop.
import numpy as np
import time
from joblib import Memory
import pandas as pd
from bokeh.charts import show, Bar
from bokeh.io import output_file, vplot
from pyearth import Earth
from pyearthnew import Earth as EarthNew
# Seed NumPy's global RNG so the synthetic benchmark data is reproducible.
np.random.seed(2)

# joblib cache directory: results of functions decorated with
# ``memory.cache`` are stored under this path.
cachedir = 'tmp'
memory = Memory(verbose=0, cachedir=cachedir)
def train_model(CLS, m, n, p, k, mr):
    """Fit one ``CLS`` model on freshly generated synthetic data.

    ``X`` is drawn as ``80 * uniform`` with shape ``(m, n)`` and the target
    is ``5 * X[:, 0] ** 2 + 6 * sin(X[:, 1]) ** 3``.  The target is computed
    from the complete matrix; only afterwards are entries of ``X`` replaced
    by NaN, each independently with probability ``mr``.

    Parameters
    ----------
    CLS : callable
        Model factory.  It is called with the keyword arguments
        ``max_terms``, ``check_every``, ``thresh``, ``minspan``, ``endspan``
        and ``allow_missing``; the returned object must provide
        ``fit(X, y)``.
    m : int
        Number of rows of ``X``.
    n : int
        Number of columns of ``X``.  Must be at least 2 because the target
        reads columns 0 and 1.
    p : int
        Accepted so callers can pass a full parameter tuple, but not used
        by this function: the generated target is always one-dimensional.
    k : int
        Forwarded to ``CLS`` as ``max_terms``.
    mr : float
        Missing rate.  ``allow_missing`` is enabled only when ``mr > 0``.

    Returns
    -------
    None
        The fitted model is discarded; the function exists to be timed.
    """
    X = 80 * np.random.uniform(size=(m, n))
    # uniform() samples the half-open interval [0, 1).  A strict comparison
    # guarantees that mr == 0 never masks an entry (with `<=`, a draw of
    # exactly 0.0 would inject a NaN while allow_missing is False).
    missing = np.random.uniform(size=X.shape) < mr
    # Build the target before blanking entries so it never contains NaN.
    y = 5 * X[:, 0] ** 2 + 6 * np.sin(X[:, 1]) ** 3
    X[missing] = np.nan
    model = CLS(
        max_terms=k,
        check_every=1,
        thresh=0,
        minspan=1,
        endspan=1,
        allow_missing=bool(mr > 0),
    )
    model.fit(X, y)
@memory.cache
def train(CLS, f, params):
    """Time ``f(CLS, *p)`` for every parameter combination in ``params``.

    Parameters
    ----------
    CLS : callable
        Model class handed to ``f`` as its first argument.
    f : callable
        Function to time.  It is invoked as ``f(CLS, *p)`` for each
        combination ``p``.  (Previously this argument was ignored and a
        hard-coded function was called instead.)
    params : sequence of sequences
        One sequence per positional argument of ``f`` (after ``CLS``).
        They are zipped together, so the i-th run uses the i-th element of
        every sequence and iteration stops at the shortest one.

    Returns
    -------
    (values, durations) : tuple of lists
        ``values[i]`` is the i-th parameter tuple and ``durations[i]`` the
        wall-clock time in seconds, measured with ``time.time()``, that the
        corresponding call took.
    """
    values = []
    durations = []
    for p in zip(*params):
        print(p)  # progress trace: one line per timed run
        values.append(p)
        start = time.time()
        f(CLS, *p)
        durations.append(time.time() - start)
    return values, durations
# Sweep ranges, one per benchmarked variable.
m = np.arange(100, 10000, 100)
n = np.arange(5, 1000, 10)
p = np.arange(1, 10, 1)
k = np.arange(5, 80, 5)

# Baseline value for each train_model argument, in order (m, n, p, k, mr).
# The last slot (missing rate) is overwritten inside the loop below.
defaults = [500, 20, 1, 15, 0]
vars_range = [m, n, p, k]
vars_caption = [
    "Number of examples",
    "Number of variables",
    # NOTE(review): train_model, as visible in this file, does not read its
    # `p` argument, so this sweep presumably measures run-to-run noise only
    # -- confirm before interpreting the chart.
    "Number of outputs",
    "Number of terms",
]

# Seed used before every timed sweep so that both implementations are fitted
# on the same synthetic data stream, independently of which calls are served
# from the joblib cache.
bench_seed = 2

output_file('bench.html')
charts = []
for missing_rate in (0, 0.3):
    for i, (var_range, cap) in enumerate(zip(vars_range, vars_caption)):
        # Hold every argument at its default, then vary only the i-th one.
        ranges = [[v] * len(var_range) for v in defaults]
        ranges[i] = var_range
        ranges[-1] = [missing_rate] * len(var_range)

        np.random.seed(bench_seed)
        _, durations = train(Earth, train_model, ranges)
        np.random.seed(bench_seed)
        _, durations_new = train(EarthNew, train_model, ranges)

        # Long-format frame: one row per (swept value, model) pair.
        df = pd.DataFrame({
            cap: list(var_range) + list(var_range),
            'time(sec)': durations + durations_new,
            'model': ['Earth'] * len(var_range) + ['EarthNew'] * len(var_range),
        })
        title = 'Duration (sec) as {} increases {}'.format(
            cap, '(with missing values)' if missing_rate != 0 else '')
        chart = Bar(df, label=cap, values='time(sec)',
                    stack='model',
                    legend='top_right',
                    plot_width=1400, plot_height=800, title=title)
        charts.append(chart)

# Stack all charts vertically in a single HTML page and open it.
fig = vplot(*charts)
show(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment