Created April 21, 2016 10:49
Save mehdidc/70940bc569677b9fed07d7247778b600 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import time
import timeit

import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from pyearth import Earth
# Seed NumPy's global RNG so every run generates the same synthetic datasets
# (all data below is drawn through the np.random.* module-level functions).
np.random.seed(seed=2)
def _make_dataset(m, n, p, mr, svr, dvr):
    """Generate a synthetic regression dataset with dependent and missing inputs.

    Each of the ``n`` columns is assigned one of three roles by a multinomial
    draw: "source" (probability ``svr``), "dest" (probability ``dvr``) or
    independent (probability ``1 - (svr + dvr)``). Dest columns are then
    overwritten with random non-negative linear combinations of the source
    columns (all zeros when there are no source columns). The ``p`` outputs
    are a random linear combination of ``5*x**2 + 6*x**3`` over the dest
    columns, computed from the complete data (all zeros when there are no
    dest columns). Finally, every input entry is replaced by NaN with
    probability ``mr``.

    Returns ``(X, y)`` with shapes ``(m, n)`` and ``(m, p)``.
    """
    # The draw order (roles, X, W, missing mask, output weights) is kept
    # fixed so that a given global seed always yields the same dataset.
    # Role codes: 0 = source, 1 = dest, 2 = independent.
    roles = np.random.multinomial(
        1, (svr, dvr, 1 - (svr + dvr)), size=n).argmax(axis=1)
    is_src = roles == 0
    is_dst = roles == 1

    X = 80 * np.random.uniform(size=(m, n))
    W = np.random.uniform(size=(np.sum(is_src), np.sum(is_dst)))
    X[:, is_dst] = np.dot(X[:, is_src], W)

    missing = np.random.uniform(size=X.shape) <= mr
    X_dst = X[:, is_dst]
    y = np.dot(5 * X_dst ** 2 + 6 * X_dst ** 3,
               np.random.uniform(size=(np.sum(is_dst), p)))

    # NaN is the missing-value marker; applied after y so targets use clean data.
    X[missing] = np.nan
    return X, y


def train_model(m, n, p, k, mr, svr, dvr):
    """Build a synthetic dataset and fit a py-earth ``Earth`` model on it.

    Parameters
    ----------
    m : number of examples (rows of X).
    n : number of input variables (columns of X).
    p : number of outputs (columns of y).
    k : ``max_terms`` passed to ``Earth``.
    mr : probability that any single input entry is set to NaN.
    svr, dvr : probabilities that a column is a "source" or a "dest"
        variable; the remainder, ``1 - (svr + dvr)``, is the probability that
        it is independent. numpy's multinomial raises ``ValueError`` when
        these probabilities are invalid (e.g. ``svr + dvr > 1``).

    Returns
    -------
    The fitted ``Earth`` model.
    """
    X, y = _make_dataset(m, n, p, mr, svr, dvr)
    # thresh=0 and check_every/minspan/endspan=1 presumably disable py-earth's
    # early stopping and knot-spacing limits, so the forward pass grows up to
    # max_terms -- confirm against the py-earth docs.
    model = Earth(max_terms=k,
                  check_every=1,
                  thresh=0,
                  minspan=1,
                  endspan=1,
                  allow_missing=True)
    model.fit(X, y)
    return model
def train(f, params):
    """Call ``f`` once per column of ``params`` and time each call.

    ``params`` holds one sequence per positional argument of ``f``; call i is
    ``f(params[0][i], params[1][i], ...)``. Iteration stops at the shortest
    sequence (``zip`` semantics). Each argument tuple is printed before its
    call.

    Returns
    -------
    values : list of tuple
        The argument tuples, in call order.
    durations : list of float
        Elapsed seconds for each call, measured with ``timeit.default_timer``
        (the platform's preferred benchmarking clock).
    """
    timer = timeit.default_timer
    values = []
    durations = []
    for args in zip(*params):
        print(args)
        values.append(args)
        start = timer()
        # Time the callable that was passed in, not a hard-coded function.
        f(*args)
        durations.append(timer() - start)
    return values, durations
def slug(s):
    """Turn a caption into a file-name stem.

    The text is lower-cased and every single space becomes a hyphen; all
    other characters (parentheses, dots, ...) are left untouched.
    """
    words = s.lower().split(" ")
    return "-".join(words)
# Sweep ranges, one per positional argument of train_model
# (m, n, p, k, mr, svr, dvr). Each benchmark varies one argument over its
# range while every other argument stays at its entry in `defaults`.
m = np.arange(100, 10000, 100)
n = np.arange(5, 1000, 10)
p = np.arange(1, 10)
k = np.arange(5, 20, 5)
# NOTE(review): the last value, mr=1.0, marks every input entry as missing;
# confirm py-earth can still fit in that case.
mr = np.linspace(0, 1, 200)
# With the other rate held at its default of 0.1, svr + dvr reaches exactly
# 1.0 at the top of each sweep (no independent columns left).
svr = np.linspace(0, 0.9, 200)
dvr = np.linspace(0, 0.9, 200)
defaults = [500, 20, 1, 15, 0.2, 0.1, 0.1]
vars_range = [m, n, p, k, mr, svr, dvr]
vars_caption = [
    "Number of examples",
    "Number of variables",
    "Number of outputs",
    "Number of terms",
    "Missing rate (0..1)",
    "Source dependent variables rate",
    "Dest dependent variables rate"
]
OUT_DIR = "missing_out"

# Create the output directory, tolerating an existing one without a
# check-then-create race (Python 2/3-compatible form of exist_ok=True).
try:
    os.makedirs(OUT_DIR)
except OSError:
    if not os.path.isdir(OUT_DIR):
        raise

for i, (var_range, cap) in enumerate(zip(vars_range, vars_caption)):
    print(cap)
    # Repeat each default once per run, then substitute the swept range
    # for argument i.
    ranges = [[v] * len(var_range) for v in defaults]
    ranges[i] = var_range
    _, y = train(train_model, ranges)
    fig, ax = plt.subplots()
    ax.plot(var_range, y)
    ax.set_xlabel(cap)
    ax.set_ylabel("duration in sec")
    fig.savefig(os.path.join(OUT_DIR, "{}.png".format(slug(cap))))
    plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.