Last active
April 15, 2022 08:37
-
-
Save cheesinglee/5351ed778f60403a15ba64e434acb1ac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import functools | |
import numpy as np | |
from matplotlib import pyplot as pp | |
from matplotlib.collections import PolyCollection | |
from matplotlib.cm import get_cmap | |
from bigml.api import BigML | |
# One BigML connection, shared by every helper in this script.
api = BigML()

# Seed for NumPy's global random generator, so the synthetic datasets
# below come out the same on every run.
DATA_SEED = 97330
np.random.seed(DATA_SEED)
def fetch_or_create_whizzml():
    """
    Fetch the model-vs-logistic WhizzML script, or create it if it does
    not exist.

    An existing script is looked up by the name ``model-vs-logistic``. If
    none is found, the script is created from ``model-vs-logistic.whizzml``
    in the working directory. It imports the ``split-dataset`` library,
    which is likewise looked up by name or created from
    ``split-dataset.whizzml``.

    Returns:
        The script resource dict returned by ``api.get_script`` (existing
        script) or ``api.create_script`` (new script, waited on).

    Raises:
        RuntimeError: if a newly created library or script does not finish
            successfully.
    """
    scripts = api.list_scripts('name=model-vs-logistic')['objects']
    if scripts:
        print('Found script: %s' % scripts[0]['resource'])
        return api.get_script(scripts[0]['resource'])

    print('Creating script: model-vs-logistic')
    libs = api.list_libraries('name=split-dataset')['objects']
    if libs:
        print('Found split-dataset library: %s' % libs[0]['resource'])
        split_dataset = api.get_library(libs[0]['resource'])
    else:
        print('Creating library: split-dataset')
        with open('split-dataset.whizzml') as fid:
            split_dataset = api.create_library(fid.read(),
                                               {'name':'split-dataset'})
        # Resource creation is asynchronous; wait for the library to be
        # ready before the script below imports it.
        if not api.ok(split_dataset):
            raise RuntimeError('Creating library split-dataset failed: %s'
                               % split_dataset.get('resource'))

    with open('model-vs-logistic.whizzml') as fid:
        script = api.create_script(fid.read(),
            {'name':'model-vs-logistic',
             'inputs':[{'name':'input-source',
                        'type':'source-id',
                        'description':'identifier for input source'}],
             'outputs':[{'name':'result',
                         'type':'list',
                         'description':'list of ids for created model and logistic regression, followed by respective f-measures'}],
             'imports':[split_dataset['resource']]})
    # Check the outcome here, so that a bad script fails at this point
    # rather than later inside an execution.
    if not api.ok(script):
        raise RuntimeError('Creating script model-vs-logistic failed: %s'
                           % script.get('resource'))
    return script
def linear_data(mu, sigma, n=1000, cx=-1, cy=1, intercept=0, xmin=0, xmax=100):
    """
    Two classes separated by a straight line.

    Each point starts on the line with direction ``(cx, cy)``, at position
    ``t * (cx, cy)`` with ``t`` drawn uniformly from ``[xmin, xmax)``. It is
    then displaced along the normal ``(-cy, cx)`` by a normally distributed
    amount: mean ``mu`` for class 1, mean ``-mu`` for class 0, standard
    deviation ``sigma`` for both. The normal is not unit length, so
    displacements scale with ``|(cx, cy)|``.

    A non-zero ``intercept`` shifts every point by ``(0, intercept)``, so
    the separating line crosses the y axis at ``intercept`` instead of at
    the origin.

    The data is written to ``linsep.csv`` (header ``x,y,class``) and
    uploaded with ``api.create_source``.

    Returns:
        ``[data, source]``. ``data`` is a ``(2n, 3)`` array of
        ``x, y, class`` rows, with the ``n`` class-1 rows first. ``source``
        is the value returned by ``api.create_source``.
    """
    direction = np.array([[cx, cy]])
    normal = np.array([[-cy, cx]])
    # Normally distributed displacements in the normal direction. Keep the
    # draw order fixed: results must stay reproducible under the seed.
    yes_dists = np.random.normal(mu, sigma, [n, 1])
    no_dists = np.random.normal(-mu, sigma, [n, 1])
    # Sample points along the boundary and add the displacements.
    yes_points = (np.random.uniform(xmin, xmax, [n, 1]) * direction
                  + yes_dists * normal)
    no_points = (np.random.uniform(xmin, xmax, [n, 1]) * direction
                 + no_dists * normal)
    if intercept:
        # Previously this parameter was accepted but ignored.
        offset = np.array([[0, intercept]])
        yes_points = yes_points + offset
        no_points = no_points + offset
    # Append the label column and stack the two classes
    # (np.row_stack is deprecated in NumPy 2.0; np.vstack is equivalent).
    data = np.vstack([np.column_stack([yes_points, np.ones(n)]),
                      np.column_stack([no_points, np.zeros(n)])])
    np.savetxt('linsep.csv', data, header='x,y,class',
               fmt='%.5f, %.5f, %d', comments='')
    return [data, api.create_source('linsep.csv')]
def radial_data(mu0, sigma0, mu1, sigma1, n=1000):
    """
    Two classes arranged in concentric semi-circular segments.

    Each class has ``n`` points at angles drawn uniformly from ``[0, pi)``.
    Radii come from ``N(mu0, sigma0)`` for class 0 and ``N(mu1, sigma1)``
    for class 1.

    The data is written to ``radsep.csv`` (header ``x,y,class``) and
    uploaded with ``api.create_source``.

    Returns:
        ``[data, source]``. ``data`` is a ``(2n, 3)`` array of
        ``x, y, class`` rows, with the ``n`` class-0 rows first. ``source``
        is the value returned by ``api.create_source``.
    """
    # Keep the draw order fixed: results must stay reproducible under the seed.
    r0 = np.random.normal(mu0, sigma0, [n, 1])
    r1 = np.random.normal(mu1, sigma1, [n, 1])
    theta0 = np.random.uniform(0, np.pi, [n, 1])
    theta1 = np.random.uniform(0, np.pi, [n, 1])
    class0 = r0 * np.column_stack([np.cos(theta0), np.sin(theta0)])
    class1 = r1 * np.column_stack([np.cos(theta1), np.sin(theta1)])
    # np.row_stack is deprecated in NumPy 2.0; np.vstack is equivalent.
    data = np.vstack([np.column_stack([class0, np.zeros(n)]),
                      np.column_stack([class1, np.ones(n)])])
    np.savetxt('radsep.csv', data, header='x,y,class',
               fmt='%.5f, %.5f, %d', comments='')
    return [data, api.create_source('radsep.csv')]
def draw_data(data):
    """
    Scatter-plot a labelled 2-D dataset on the current axes.

    ``data`` holds rows of ``x, y, class``, as produced by ``linear_data``
    or ``radial_data``. Class-0 points are drawn as blue dots and class-1
    points as red dots, then the axes are fitted tightly to the data.
    """
    data = np.asarray(data)
    labels = data[:, 2]
    # Vectorized boolean masks replace the per-row Python list comprehension.
    class0 = data[labels == 0]
    class1 = data[labels == 1]
    pp.plot(class0[:, 0], class0[:, 1], 'b.',
            class1[:, 0], class1[:, 1], 'r.')
    pp.axis('tight')
def draw_node(node, xmin=None, xmax=None, ymin=None, ymax=None,
              last_output=None):
    """
    Collect polygons and fill colours for a decision-tree subtree.

    Takes a BigML tree node: a dict with ``predicate`` and, for non-root
    nodes, ``output`` and ``confidence``, plus optional ``children``. For
    each non-root node that splits on field ``000000`` (x) or ``000001`` (y)
    with an operator starting with ``<`` or ``>``, it builds the rectangle
    of the plane the node covers, clipped to the bounds inherited from its
    ancestors. Nodes with any other field or operator produce no polygon
    and pass their bounds through unchanged.

    Args:
        node: tree node dict. The root is recognised by ``predicate is True``.
        xmin, xmax, ymin, ymax: bounds of the parent's region. ``None`` is
            replaced by the corresponding current axis limit.
        last_output: the parent's ``output``. A node predicting the same
            value is outlined but not filled.

    Returns:
        ``(verts, fills)``: a list of polygons (each a list of ``[x, y]``
        vertices) for this node and all its descendants, and a parallel
        list of face colours, suitable for ``PolyCollection``.
    """
    pred = node['predicate']
    if pred is True:
        # Root node: nothing to draw. Its children are compared against
        # output None, so they are always filled.
        output = None
        verts, fills = [], []
    else:
        # Use `is None`, not `or`: a split threshold of 0.0 is a valid bound
        # and must not be replaced by the axis limit.
        if xmin is None or xmax is None:
            x_lo, x_hi = pp.xlim()
            xmin = x_lo if xmin is None else xmin
            xmax = x_hi if xmax is None else xmax
        if ymin is None or ymax is None:
            y_lo, y_hi = pp.ylim()
            ymin = y_lo if ymin is None else ymin
            ymax = y_hi if ymax is None else ymax
        value = pred['value']
        field = pred['field']
        operator = pred['operator']
        output = node['output']
        confidence = node['confidence']

        verts_node = None
        if field == '000000':
            if operator[0] == '<':
                verts_node = [[value, ymin], [value, ymax],
                              [xmin, ymax], [xmin, ymin]]
                xmax = value
            elif operator[0] == '>':
                verts_node = [[value, ymin], [value, ymax],
                              [xmax, ymax], [xmax, ymin]]
                xmin = value
        elif field == '000001':
            if operator[0] == '<':
                verts_node = [[xmin, value], [xmax, value],
                              [xmax, ymin], [xmin, ymin]]
                ymax = value
            elif operator[0] == '>':
                verts_node = [[xmin, value], [xmax, value],
                              [xmax, ymax], [xmin, ymax]]
                ymin = value

        if verts_node is None:
            # Unsupported field or operator: no rectangle can be drawn.
            verts, fills = [], []
        else:
            if output == last_output:
                fill_node = 'none'
            elif output in ('0', '1'):
                # pyplot.get_cmap survives Matplotlib 3.9, which removed
                # matplotlib.cm.get_cmap.
                bwr = pp.get_cmap('bwr')
                fill_node = bwr(confidence if output == '1'
                                else 1 - confidence)
            else:
                # Not one of the two expected class labels: outline only.
                fill_node = 'none'
            verts, fills = [verts_node], [fill_node]

    for child in node.get('children', ()):
        child_verts, child_fills = draw_node(child, xmin, xmax, ymin, ymax,
                                             output)
        verts.extend(child_verts)
        fills.extend(child_fills)
    return verts, fills
def draw_model_splits(res):
    """
    Shade the regions carved out by a BigML decision tree's splits.

    ``res`` is passed to ``api.get_model``. The tree's root node is turned
    into polygons by ``draw_node``, and these are added to the current axes
    as one semi-transparent ``PolyCollection`` with black edges.
    """
    tree_root = api.get_model(res)['object']['model']['root']
    polygons, face_colours = draw_node(tree_root)
    regions = PolyCollection(polygons, facecolors=face_colours,
                             edgecolors='k', alpha=0.4, lw=2)
    regions.set_cmap('bwr')
    pp.gca().add_collection(regions)
def logistic(cx, cy, intercept, x, y):
    """
    Return the logistic sigmoid of the linear score ``cx*x + cy*y + intercept``.

    ``x`` and ``y`` may be scalars or NumPy arrays; arrays are evaluated
    element-wise.
    """
    score = cx * x + cy * y + intercept
    return 1 / (1 + np.exp(-score))
def draw_logistic_boundaries(res):
    """
    Overlay a BigML logistic regression's decision boundaries on the current axes.

    For each ``[label, coefficients]`` pair in the model's ``coefficients``,
    the first element of entries 0, 1 and 2 is treated as the x weight, the
    y weight and the intercept. Two things are drawn for each pair:

    * the line where ``cx*x + cy*y + intercept == 0`` (a vertical line when
      the y weight is zero);
    * ``logistic`` of that score as a translucent ``bwr`` image over the
      current axis limits. For label ``'0'`` the image is inverted,
      presumably so that high (red) values always favour class ``'1'``.
    """
    lr = api.get_logistic_regression(res)['object']['logistic_regression']
    xs = np.array(pp.xlim())
    for label, cs in lr['coefficients']:
        cx = cs[0][0]
        cy = cs[1][0]
        intercept = cs[2][0]
        if cy != 0:
            pp.plot(xs, (cx*xs + intercept)/-cy, lw=2)
        elif cx != 0:
            # Vertical boundary x = -intercept/cx. Solving for y, as above,
            # would divide by zero and draw nothing.
            x_boundary = -intercept/cx
            pp.plot([x_boundary, x_boundary], pp.ylim(), lw=2)
        ys = np.array(pp.ylim())
        logistic_fn = functools.partial(logistic, cx, cy, intercept)
        x_grid, y_grid = np.meshgrid(np.linspace(xs[0], xs[1]),
                                     np.linspace(ys[0], ys[1]))
        probs = logistic_fn(x_grid, y_grid)
        if label == '0':
            probs = 1 - probs
        pp.imshow(probs, alpha=0.2, cmap='bwr', aspect='auto',
                  extent=(xs[0], xs[1], ys[0], ys[1]), origin='lower')
def make_plots(data, src, script, name):
    """
    Run the model-vs-logistic script on a source and plot both classifiers.

    Executes ``script`` with ``src['resource']`` as its ``input-source`` and
    waits for the execution. Its result is unpacked as
    ``[model, logistic_regression, model_f, lr_f]``. Two figures are then
    drawn over a scatter plot of ``data`` and saved with transparent
    backgrounds:

    * ``lr_boundary_<name>.png``: the logistic-regression boundaries;
    * ``model_boundary_<name>.png``: the decision-tree split regions.

    Raises:
        RuntimeError: if the execution does not finish successfully.
    """
    src_id = src['resource']
    print('Running model-vs-logistic script with source ID %s' % src_id)
    ex = api.create_execution(script,
                              {'inputs':[['input-source',src_id]]})
    # api.ok reports failure through its return value; check it instead of
    # letting a missing result fail obscurely in the unpacking below.
    if not api.ok(ex):
        raise RuntimeError('Execution %s did not finish successfully'
                           % ex.get('resource'))
    model, lr, model_f, lr_f = ex['object']['execution']['result']

    pp.figure()
    draw_data(data)
    draw_logistic_boundaries(lr)
    pp.title('Logistic Regression, f-measure = %f' % lr_f)
    pp.grid()
    pp.savefig('lr_boundary_%s.png' % name, transparent=True)

    pp.figure()
    draw_data(data)
    draw_model_splits(model)
    pp.grid()
    pp.title('Decision Tree, f-measure = %f' % model_f)
    pp.savefig('model_boundary_%s.png' % name, transparent=True)
if __name__ == '__main__':
    script = fetch_or_create_whizzml()
    # Build both datasets before plotting, linear first, so the random draws
    # happen in the same order as before.
    experiments = [('linear', linear_data(20, 15, cx=3)),
                   ('radial', radial_data(10, 3, 40, 5))]
    for label, (points, source) in experiments:
        make_plots(points, source, script, label)
    pp.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
;; Generate an 80/20 training/test split from the input source, train a
;; decision tree and a logistic regression on the training part, and
;; evaluate both on the test part. Returns a list with the ids of the
;; model and the logistic regression, followed by their respective
;; average f-measures, as declared in the script's `result` output.
;; The argument is a source id, matching the script's `source-id` input.
(define (model-vs-logistic source-id)
  (let (ds (create-and-wait-dataset source-id)
        ids (split-dataset ds 0.8)
        ds-train (nth ids 0)
        ds-test (nth ids 1)
        m (create-and-wait-model ds-train)
        lr (create-and-wait-logisticregression ds-train)
        m-eval (create-and-wait-evaluation {"model" m "dataset" ds-test})
        lr-eval (create-and-wait-evaluation {"logisticregression" lr
                                             "dataset" ds-test}))
    ;; The evaluations were previously computed but never returned.
    (list m
          lr
          (get-in (fetch m-eval) ["result" "model" "average_f_measure"])
          (get-in (fetch lr-eval) ["result" "model" "average_f_measure"]))))

;; Bind the declared script output from the declared script input.
(define result (model-vs-logistic input-source))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
;; Sample a fraction `rate` of the rows of dataset `ds-id` and wait for the
;; new dataset. When `oob` is true, take the out-of-bag rows instead. The
;; seed is fixed, so calls with the same `ds-id` and `rate` sample the
;; same rows.
(define (sample-dataset ds-id rate oob)
  (create-and-wait-dataset {"origin_dataset" ds-id
                            "sample_rate" rate
                            "out_of_bag" oob
                            "seed" "whizzml-example"}))

;; Split `ds-id` into a two-element list: the in-bag sample (fraction
;; `rate`) first, then its out-of-bag complement.
(define (split-dataset ds-id rate)
  (map (lambda (oob) (sample-dataset ds-id rate oob))
       [false true]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment