Skip to content

Instantly share code, notes, and snippets.

@mehdidc
Created April 21, 2016 10:49
Show Gist options
  • Save mehdidc/70940bc569677b9fed07d7247778b600 to your computer and use it in GitHub Desktop.
Save mehdidc/70940bc569677b9fed07d7247778b600 to your computer and use it in GitHub Desktop.
import numpy as np
import os
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from pyearth import Earth
import time
# Seed NumPy's global RNG once at import time so that every run of this
# benchmark draws the same synthetic datasets (the data generation below
# uses the global np.random.* functions).
np.random.seed(2)
def _make_synthetic_data(m, n, p, mr, svr, dvr):
    """Draw a random regression problem with correlated and missing inputs.

    Each of the ``n`` columns gets one of three roles from a multinomial draw
    with probabilities ``(svr, dvr, 1 - svr - dvr)``:

    * role 0: "source" column, uniform noise in [0, 80);
    * role 1: "dependent" column, a random non-negative linear combination
      of the source columns;
    * role 2: independent column, uniform noise in [0, 80).

    Returns ``(X, y)``.

    * ``X`` has shape ``(m, n)``. Each entry is independently replaced by NaN
      with probability ``mr``.
    * ``y`` has shape ``(m, p)``. It is a random linear map of the cubic
      features ``5*d**2 + 6*d**3`` of the dependent columns ``d``. It is
      computed before any entry of ``X`` is blanked out.

    Raises ``ValueError`` if ``svr + dvr`` exceeds 1.
    """
    rest = 1.0 - (svr + dvr)
    if rest < -1e-12:
        raise ValueError(
            "svr + dvr must not exceed 1, got {} + {}".format(svr, dvr))
    # Clamp pure round-off (e.g. a result like -1e-17): multinomial rejects
    # any negative probability.
    rest = max(rest, 0.0)
    roles = np.random.multinomial(1, (svr, dvr, rest), size=n).argmax(axis=1)
    is_src = roles == 0
    is_dep = roles == 1
    X = 80 * np.random.uniform(size=(m, n))
    nb_src = int(is_src.sum())
    nb_dep = int(is_dep.sum())
    W = np.random.uniform(size=(nb_src, nb_dep))
    X[:, is_dep] = np.dot(X[:, is_src], W)
    # uniform() draws lie in [0, 1), so a strict '<' gives P(missing) == mr
    # exactly; '<=' could flag an entry even when mr == 0.
    missing = np.random.uniform(size=X.shape) < mr
    dep = X[:, is_dep]
    y = np.dot(5 * dep ** 2 + 6 * dep ** 3,
               np.random.uniform(size=(nb_dep, p)))
    X[missing] = np.nan
    return X, y


def train_model(m, n, p, k, mr, svr, dvr):
    """Fit one pyearth ``Earth`` model on a freshly generated synthetic dataset.

    This is the unit of work whose wall-clock time the script benchmarks.

    Parameters
    ----------
    m : number of examples (rows of X).
    n : number of input variables (columns of X).
    p : number of outputs (columns of y).
    k : value passed to ``Earth`` as ``max_terms``.
    mr : probability in [0, 1] that any single entry of X is set to NaN.
    svr : probability that a column is a "source" variable.
    dvr : probability that a column is a "dependent" variable, i.e. a random
        linear combination of the source columns.

    Returns
    -------
    The fitted ``Earth`` model.

    Raises
    ------
    ValueError
        If ``svr + dvr`` exceeds 1.
    """
    X, y = _make_synthetic_data(m, n, p, mr, svr, dvr)
    model = Earth(max_terms=k,
                  check_every=1,
                  thresh=0,
                  minspan=1,
                  endspan=1,
                  allow_missing=True)
    model.fit(X, y)
    return model
def train(f, params):
    """Call ``f`` once per parameter combination and time each call.

    ``params`` is a sequence of sequences, one per positional argument of
    ``f``. The i-th call receives the i-th element of each sequence
    (``zip(*params)``), so iteration stops at the shortest one.

    Returns ``(values, durations)``:

    * ``values``: the argument tuples, in call order;
    * ``durations``: the matching wall-clock time of each call, in seconds.

    Each argument tuple is printed before its call as a progress trace.
    """
    values = []
    durations = []
    for args in zip(*params):
        print(args)
        values.append(args)
        start = time.time()
        # Invoke the callable supplied by the caller, not a hard-coded one.
        f(*args)
        durations.append(time.time() - start)
    return values, durations
def slug(s):
    """Turn a caption into a file-name stem.

    Lower-cases the text and replaces every single space with a hyphen; all
    other characters are kept unchanged.
    """
    words = s.lower().split(" ")
    return "-".join(words)
# Sweep grids, one per positional argument of train_model,
# in order (m, n, p, k, mr, svr, dvr).
m, n, p, k = (
    np.arange(100, 10000, 100),  # number of examples
    np.arange(5, 1000, 10),      # number of variables
    np.arange(1, 10),            # number of outputs
    np.arange(5, 20, 5),         # max_terms handed to Earth
)
# Rate parameters, each swept over 200 evenly spaced points.
mr, svr, dvr = (
    np.linspace(0, 1, 200),
    np.linspace(0, 0.9, 200),
    np.linspace(0, 0.9, 200),
)
# Value every argument is held at while a different one is being swept
# (same order as the grids above).
defaults = [500, 20, 1, 15, 0.2, 0.1, 0.1]
vars_range = [m, n, p, k, mr, svr, dvr]
vars_caption = [
    "Number of examples",
    "Number of variables",
    "Number of outputs",
    "Number of terms",
    "Missing rate (0..1)",
    "Source dependent variables rate",
    "Dest dependent variables rate"
]
# Directory that receives one timing plot per swept parameter.
if not os.path.exists("missing_out"):
    os.mkdir("missing_out")
# One-at-a-time sweep: for parameter i, argument i walks through its whole
# grid while every other argument stays at its default. Durations are plotted
# against the swept values, one PNG per parameter in missing_out/.
for i, (var_range, cap) in enumerate(zip(vars_range, vars_caption)):
    print(cap)
    fig = plt.figure()
    ranges = [var_range if j == i else [default] * len(var_range)
              for j, default in enumerate(defaults)]
    _, y = train(train_model, ranges)
    ax = fig.gca()
    ax.plot(var_range, y)
    ax.set_xlabel(cap)
    ax.set_ylabel("duration in sec")
    fig.savefig("missing_out/{}.png".format(slug(cap)))
    plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment