Created
March 3, 2020 23:27
-
-
Save guillefix/a9b71f2551e1a2240dc206db9a2730b7 to your computer and use it in GitHub Desktop.
Some useful Plotly plots for generalization-error / complexity data.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import plotly.plotly as py | |
from plotly.graph_objs import * | |
def nice_2dhist(x,y,nbins,title='title',xlabel='x',ylabel='y',filename='nice-hist.png'):
    """Save a scatter plot over a 2D density contour, with marginal histograms.

    The main panel shows the (x, y) points as small white markers on top of a
    filled ``histogram2dcontour`` of the same data. A histogram of ``x`` sits
    above it (axis ``y2``) and a histogram of ``y`` sits to its right (axis
    ``x2``). The figure is written with ``py.image.save_as``.

    Args:
        x, y: equal-length sequences of numbers.
        nbins: number of bins spanning ``[min, max]`` of each variable; used
            by the 2D contour and by both marginal histograms.
        title: figure title.
        xlabel, ylabel: titles of the main x and y axes.
        filename: output image path.

    Plotly credentials are read from the ``PLOTLY_USERNAME`` and
    ``PLOTLY_API_KEY`` environment variables when both are set. Otherwise no
    ``py.sign_in`` call is made and Plotly uses whatever credentials it
    already has configured.

    Raises:
        ValueError: if ``x`` or ``y`` is empty, or ``nbins`` is not positive.
    """
    import os

    if len(x) == 0 or len(y) == 0:
        raise ValueError("nice_2dhist: x and y must be non-empty")
    if nbins <= 0:
        raise ValueError("nice_2dhist: nbins must be positive")

    # Credentials used to be hard-coded here and remain in the source history;
    # treat them as compromised (rotate them) and never commit keys again.
    username = os.environ.get("PLOTLY_USERNAME")
    api_key = os.environ.get("PLOTLY_API_KEY")
    if username and api_key:
        py.sign_in(username=username, api_key=api_key)

    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)

    def _bins(lo, hi):
        # Build a fresh dict per trace. If every value is identical the width
        # would be 0, so fall back to 1.0 to keep the bin spec non-degenerate.
        width = (hi - lo) / nbins
        return {"end": hi, "size": width if width > 0 else 1.0, "start": lo}

    def _axis(domain, lo, hi, axis_title):
        # All four axes share the same styling; only placement/labels differ.
        return {
            "autorange": True,
            "domain": domain,
            "range": [lo, hi],
            "showgrid": False,
            "title": axis_title,
            "type": "linear",
            "zeroline": False,
        }

    points = {
        "x": x,
        "y": y,
        "marker": {
            "color": "rgb(255, 255, 255)",
            "line": {"width": 0.5},
            "opacity": 0.4,
            "size": 4,
        },
        "mode": "markers",
        "name": "points",
        "opacity": 0.75,
        "text": [],
        "type": "scatter",
        "uid": "eb94b3",
    }
    density = {
        "x": x,
        "y": y,
        "autocolorscale": False,
        "colorscale": [
            [0, "rgb(8, 29, 88)"], [0.125, "rgb(37, 52, 148)"],
            [0.25, "rgb(34, 94, 168)"], [0.375, "rgb(29, 145, 192)"],
            [0.5, "rgb(65, 182, 196)"], [0.625, "rgb(127, 205, 187)"],
            [0.75, "rgb(199, 233, 180)"], [0.875, "rgb(237, 248, 217)"],
            [1, "rgb(255, 255, 217)"],
        ],
        # Contour levels and colour limits are fixed constants; they do not
        # adapt to the data.
        "contours": {
            "coloring": "fill",
            "end": 80.05,
            "showlines": True,
            "size": 5,
            "start": 5,
        },
        "name": "density",
        "ncontours": 20,
        "reversescale": False,
        "showscale": False,
        "type": "histogram2dcontour",
        "uid": "b20cd7",
        "xbins": _bins(x_min, x_max),
        "ybins": _bins(y_min, y_max),
        "zmax": 83,
        "zmin": 0,
    }
    x_hist = {
        "x": x,
        "marker": {"color": "rgb(31, 119, 180)"},
        "name": "x density",
        "type": "histogram",
        "uid": "70efa7",
        "xbins": _bins(x_min, x_max),
        "yaxis": "y2",
    }
    y_hist = {
        "y": y,
        "marker": {"color": "rgb(33, 113, 181)"},
        "name": "y density",
        "type": "histogram",
        "uid": "73ca31",
        "xaxis": "x2",
        "ybins": _bins(y_min, y_max),
    }
    layout = {
        "autosize": False,
        "bargap": 0,
        "height": 700,
        "hovermode": "closest",
        "margin": {"t": 50},
        "paper_bgcolor": "rgb(249, 249, 249)",
        "plot_bgcolor": "rgb(249, 249, 249)",
        "showlegend": False,
        "title": title,
        "width": 800,
        "xaxis": _axis([0, 0.85], x_min, x_max, xlabel),
        "xaxis2": _axis([0.85, 1], x_min, x_max, ""),
        "yaxis": _axis([0, 0.85], y_min, y_max, ylabel),
        "yaxis2": _axis([0.85, 1], y_min, y_max, ""),
    }
    fig = Figure(data=[points, density, x_hist, y_hist], layout=layout)
    py.image.save_as(fig, filename=filename)
# import matplotlib.pyplot as plt | |
# | |
# %matplotlib inline | |
# | |
# plt.scatter(final_LZs[0], gen_errors[0]) | |
# | |
# %matplotlib | |
# plt.clf() | |
# | |
# def forceAspect(ax,aspect=1): | |
# im = ax.get_images() | |
# extent = im[0].get_extent() | |
# ax.set_aspect(abs((extent[1]-extent[0])/(extent[3]-extent[2]))/aspect) | |
# | |
# | |
# idx=11 | |
# heatmap, xedges, yedges = np.histogram2d(final_entss[idx], gen_errors[idx], bins=20) | |
# extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]] | |
# | |
# fig = plt.figure() | |
# ax = fig.add_subplot(111) | |
# | |
# # plt.ylim(yedges[0], yedges[-1]) | |
# # plt.figure(figsize=(3,3)) | |
# ax.imshow(heatmap.T, extent=extent, origin='lower') | |
# forceAspect(ax,aspect=1) | |
# fig.show() | |
# | |
# from matplotlib import cm as CM | |
# | |
# idx=3 | |
# plt.hexbin(final_entss[idx], gen_errors[idx],gridsize=20,cmap=CM.jet, bins=None) | |
# plt.show() | |
# x=final_entss[idx] | |
# y=gen_errors[idx] | |
# nbins=10 | |
def nice_2dhist_double(xa,ya,xb,yb,nbins,title='title',xlabel='x',ylabel='y',filename='nice-hist.png'):
    """Save a two-series scatter plot with a shared density contour and marginals.

    Series "a" is drawn as blue circles labelled "Neural network". Series "b"
    is drawn as red diamonds labelled "Unbiased learner". A greyscale filled
    ``histogram2dcontour`` of the pooled data sits under the points. Per-series
    histograms of x are drawn on top (axis ``y2``) and of y on the right (axis
    ``x2``). All bin edges come from the pooled data. The figure is written
    with ``py.image.save_as``.

    Args:
        xa, ya: coordinates of series "a" (equal length).
        xb, yb: coordinates of series "b" (equal length).
        nbins: number of bins across the pooled ``[min, max]`` of each variable.
        title: figure title.
        xlabel, ylabel: titles of the main x and y axes.
        filename: output image path.

    Plotly credentials are read from the ``PLOTLY_USERNAME`` and
    ``PLOTLY_API_KEY`` environment variables when both are set. Otherwise no
    ``py.sign_in`` call is made and Plotly uses whatever credentials it
    already has configured.

    Raises:
        ValueError: if the pooled x or y data is empty, or ``nbins`` is not
            positive.
    """
    import os

    # Pool both series. Converting to lists first makes this a concatenation
    # for any sequence type; for NumPy arrays ``xa + xb`` would instead add
    # element-wise (or fail on a length mismatch).
    x = list(xa) + list(xb)
    y = list(ya) + list(yb)
    if not x or not y:
        raise ValueError("nice_2dhist_double: pooled x and y must be non-empty")
    if nbins <= 0:
        raise ValueError("nice_2dhist_double: nbins must be positive")

    # Credentials used to be hard-coded here and remain in the source history;
    # treat them as compromised (rotate them) and never commit keys again.
    username = os.environ.get("PLOTLY_USERNAME")
    api_key = os.environ.get("PLOTLY_API_KEY")
    if username and api_key:
        py.sign_in(username=username, api_key=api_key)

    fontsize = 26
    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)

    def _width(lo, hi):
        # Identical values would give a zero width; fall back to 1.0.
        width = (hi - lo) / nbins
        return width if width > 0 else 1.0

    x_width = _width(x_min, x_max)
    y_width = _width(y_min, y_max)

    def _xbins():
        return {"end": x_max, "size": x_width, "start": x_min}

    def _ybins():
        return {"end": y_max, "size": y_width, "start": y_min}

    scatter_a = {
        "x": xa,
        "y": ya,
        "marker": {
            "symbol": "circle",
            "color": "rgb(31, 119, 180)",
            "line": {"width": 1},
            "opacity": 0.4,
            "size": 8,
        },
        "mode": "markers",
        "name": "Neural network",
        "opacity": 0.75,
        "text": [],
        "type": "scatter",
        "uid": "eb94b3a",  # distinct uid per trace (previously shared)
    }
    scatter_b = {
        "x": xb,
        "y": yb,
        "marker": {
            "symbol": "diamond",
            "color": "red",
            "line": {"width": 0.1},
            "opacity": 0.4,
            "size": 8,
        },
        "mode": "markers",
        "name": "Unbiased learner",
        "opacity": 0.75,
        "text": [],
        "type": "scatter",
        "uid": "eb94b3b",
    }
    density = {
        "x": x,
        "y": y,
        "autocolorscale": False,
        "colorscale": 'Greys',
        "contours": {
            "coloring": "fill",
            "end": 100,
            "showlines": False,
            "size": 5,
            "start": 5,
        },
        "name": "density",
        "ncontours": 50,
        "reversescale": True,
        "showscale": False,
        "type": "histogram2dcontour",
        "uid": "b20cd7",
        # The contour bins are padded beyond the data so the filled density
        # does not stop abruptly at the outermost points.
        "xbins": {
            "end": x_max * 1.1,
            "size": x_width,
            "start": x_min - x_max * 0.1,
        },
        "ybins": {
            "end": y_max * 1.1,
            "size": y_width,
            "start": y_min - 0.05,
        },
        "zmax": 50,
        "zmin": 0,
    }
    x_hist_a = {
        "x": xa,
        "marker": {"color": "rgb(31, 119, 180)"},
        "name": "Entropy histogram, NN",
        "type": "histogram",
        "showlegend": False,
        "uid": "70efa7a",
        "xbins": _xbins(),
        "yaxis": "y2",
    }
    x_hist_b = {
        "x": xb,
        "marker": {"color": "red"},
        "name": "Entropy histogram, unbiased",
        "type": "histogram",
        "showlegend": False,
        "uid": "70efa7b",
        "xbins": _xbins(),
        "yaxis": "y2",
    }
    y_hist_a = {
        "y": ya,
        "marker": {"color": "rgb(33, 113, 181)"},
        "name": "Error histogram, NN",
        "type": "histogram",
        "showlegend": False,
        "uid": "73ca31a",
        "xaxis": "x2",
        "ybins": _ybins(),
    }
    y_hist_b = {
        "y": yb,
        "marker": {"color": "red"},
        "name": "Error histogram, unbiased",
        "type": "histogram",
        "showlegend": False,
        "uid": "73ca31b",
        "xaxis": "x2",
        "ybins": _ybins(),
    }
    data = [scatter_a, x_hist_a, y_hist_a, scatter_b, density, x_hist_b, y_hist_b]
    layout = {
        "autosize": False,
        "bargap": 0,
        "height": 700,
        "hovermode": "closest",
        "margin": {"t": 40},
        "showlegend": True,
        "font": dict(size=fontsize, color='black'),
        "legend": {
            "x": 0.07,
            "y": 0.8,
            "bgcolor": '#E2E2E2',
            "bordercolor": '#FFFFFF',
            "borderwidth": 2,
        },
        "title": title,
        "width": 800,
        "xaxis": {
            "autorange": True,
            "domain": [0, 0.85],
            "range": [x_min, x_max],
            "showgrid": False,
            "title": xlabel,
            "type": "linear",
            "zeroline": False,
            "showline": True,
            "titlefont": dict(size=fontsize, color='black'),
        },
        "xaxis2": {
            "autorange": True,
            "domain": [0.85, 1],
            "range": [x_min, x_max],
            "showgrid": False,
            "showticklabels": False,
            "title": "",
            "type": "linear",
            "zeroline": True,
        },
        "yaxis": {
            "autorange": True,
            "domain": [0, 0.85],
            "range": [y_min, y_max],
            "showgrid": False,
            "title": ylabel,
            "type": "linear",
            "zeroline": False,
            "showline": True,
            "titlefont": dict(size=fontsize, color='black'),
        },
        "yaxis2": {
            "autorange": True,
            "domain": [0.85, 1],
            "range": [y_min, y_max],
            "showgrid": False,
            "showticklabels": False,
            "title": "",
            "type": "linear",
            "zeroline": True,
        },
    }
    fig = Figure(data=data, layout=layout)
    py.image.save_as(fig, filename=filename)
def nice_2dhist_double_tight(xa,ya,xb,yb,nbins,title='title',xlabel='x',ylabel='y',filename='nice-hist.png'):
    """Tighter, legend-free variant of the two-series density/scatter figure.

    This draws the same traces as the two-series plot: blue "Neural network"
    circles for series "a", red "Unbiased learner" diamonds for series "b", a
    greyscale pooled ``histogram2dcontour``, and per-series marginal
    histograms. The layout differs: it hides the legend, uses 30pt fonts, a
    70px top margin, and shifts the main x-domain to ``[0.05, 0.9]``. The
    figure is written with ``py.image.save_as``.

    Args:
        xa, ya: coordinates of series "a" (equal length).
        xb, yb: coordinates of series "b" (equal length).
        nbins: number of bins across the pooled ``[min, max]`` of each variable.
        title: figure title.
        xlabel, ylabel: titles of the main x and y axes.
        filename: output image path.

    Plotly credentials are read from the ``PLOTLY_USERNAME`` and
    ``PLOTLY_API_KEY`` environment variables when both are set. Otherwise no
    ``py.sign_in`` call is made and Plotly uses whatever credentials it
    already has configured.

    Raises:
        ValueError: if the pooled x or y data is empty, or ``nbins`` is not
            positive.
    """
    import os

    # Pool both series. Converting to lists first makes this a concatenation
    # for any sequence type; for NumPy arrays ``xa + xb`` would instead add
    # element-wise (or fail on a length mismatch).
    x = list(xa) + list(xb)
    y = list(ya) + list(yb)
    if not x or not y:
        raise ValueError("nice_2dhist_double_tight: pooled x and y must be non-empty")
    if nbins <= 0:
        raise ValueError("nice_2dhist_double_tight: nbins must be positive")

    # Credentials used to be hard-coded here and remain in the source history;
    # treat them as compromised (rotate them) and never commit keys again.
    username = os.environ.get("PLOTLY_USERNAME")
    api_key = os.environ.get("PLOTLY_API_KEY")
    if username and api_key:
        py.sign_in(username=username, api_key=api_key)

    fontsize = 30
    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)

    def _width(lo, hi):
        # Identical values would give a zero width; fall back to 1.0.
        width = (hi - lo) / nbins
        return width if width > 0 else 1.0

    x_width = _width(x_min, x_max)
    y_width = _width(y_min, y_max)

    def _xbins():
        return {"end": x_max, "size": x_width, "start": x_min}

    def _ybins():
        return {"end": y_max, "size": y_width, "start": y_min}

    scatter_a = {
        "x": xa,
        "y": ya,
        "marker": {
            "symbol": "circle",
            "color": "rgb(31, 119, 180)",
            "line": {"width": 1},
            "opacity": 0.4,
            "size": 8,
        },
        "mode": "markers",
        "name": "Neural network",
        "opacity": 0.75,
        "text": [],
        "type": "scatter",
        "uid": "eb94b3a",  # distinct uid per trace (previously shared)
    }
    scatter_b = {
        "x": xb,
        "y": yb,
        "marker": {
            "symbol": "diamond",
            "color": "red",
            "line": {"width": 0.1},
            "opacity": 0.4,
            "size": 8,
        },
        "mode": "markers",
        "name": "Unbiased learner",
        "opacity": 0.75,
        "text": [],
        "type": "scatter",
        "uid": "eb94b3b",
    }
    density = {
        "x": x,
        "y": y,
        "autocolorscale": False,
        "colorscale": 'Greys',
        "contours": {
            "coloring": "fill",
            "end": 100,
            "showlines": False,
            "size": 5,
            "start": 5,
        },
        "name": "density",
        "ncontours": 50,
        "reversescale": True,
        "showscale": False,
        "type": "histogram2dcontour",
        "uid": "b20cd7",
        # The contour bins are padded beyond the data so the filled density
        # does not stop abruptly at the outermost points.
        "xbins": {
            "end": x_max * 1.1,
            "size": x_width,
            "start": x_min - x_max * 0.1,
        },
        "ybins": {
            "end": y_max * 1.1,
            "size": y_width,
            "start": y_min - 0.05,
        },
        "zmax": 50,
        "zmin": 0,
    }
    x_hist_a = {
        "x": xa,
        "marker": {"color": "rgb(31, 119, 180)"},
        "name": "Entropy histogram, NN",
        "type": "histogram",
        "showlegend": False,
        "uid": "70efa7a",
        "xbins": _xbins(),
        "yaxis": "y2",
    }
    x_hist_b = {
        "x": xb,
        "marker": {"color": "red"},
        "name": "Entropy histogram, unbiased",
        "type": "histogram",
        "showlegend": False,
        "uid": "70efa7b",
        "xbins": _xbins(),
        "yaxis": "y2",
    }
    y_hist_a = {
        "y": ya,
        "marker": {"color": "rgb(33, 113, 181)"},
        "name": "Error histogram, NN",
        "type": "histogram",
        "showlegend": False,
        "uid": "73ca31a",
        "xaxis": "x2",
        "ybins": _ybins(),
    }
    y_hist_b = {
        "y": yb,
        "marker": {"color": "red"},
        "name": "Error histogram, unbiased",
        "type": "histogram",
        "showlegend": False,
        "uid": "73ca31b",
        "xaxis": "x2",
        "ybins": _ybins(),
    }
    data = [scatter_a, x_hist_a, y_hist_a, scatter_b, density, x_hist_b, y_hist_b]
    layout = {
        "autosize": False,
        "bargap": 0,
        "height": 700,
        "hovermode": "closest",
        "margin": {"t": 70},
        "showlegend": False,
        "font": dict(size=fontsize, color='black'),
        "legend": {
            "x": 0.07,
            "y": 0.8,
            "bgcolor": '#E2E2E2',
            "bordercolor": '#FFFFFF',
            "borderwidth": 2,
        },
        "title": title,
        "width": 800,
        "xaxis": {
            "autorange": True,
            "domain": [0.05, 0.9],
            "range": [x_min, x_max],
            "showgrid": False,
            "title": xlabel,
            "type": "linear",
            "zeroline": False,
            "showline": True,
            "titlefont": dict(size=fontsize, color='black'),
        },
        "xaxis2": {
            "autorange": True,
            "domain": [0.9, 1],
            "range": [x_min, x_max],
            "showgrid": False,
            "showticklabels": False,
            "title": "",
            "type": "linear",
            "zeroline": True,
        },
        "yaxis": {
            "autorange": True,
            "domain": [0, 0.85],
            "range": [y_min, y_max],
            "showgrid": False,
            "title": ylabel,
            "type": "linear",
            "zeroline": False,
            "showline": True,
            "titlefont": dict(size=fontsize, color='black'),
        },
        "yaxis2": {
            "autorange": True,
            "domain": [0.85, 1],
            "range": [y_min, y_max],
            "showgrid": False,
            "showticklabels": False,
            "title": "",
            "type": "linear",
            "zeroline": True,
        },
    }
    fig = Figure(data=data, layout=layout)
    py.image.save_as(fig, filename=filename)
import numpy as np | |
def shaded_std_plot(x1,y1s,filename,xlabel,ylabel,plotline=False):
    """Save a plot of per-x means with a shaded mean +/- 1 std band.

    Points are sorted by x. At each x the marker is ``np.mean`` of its sample
    collection. The translucent blue band spans ``mean - np.std`` to
    ``mean + np.std`` of that same collection (NumPy's default ``ddof=0``).

    Args:
        x1: x positions.
        y1s: one collection of samples per x position (paired with ``x1`` via
            ``zip``, so extra items in the longer argument are ignored).
        filename: output image path for ``py.image.save_as``.
        xlabel, ylabel: axis titles.
        plotline: if True, also draw a straight line from (0, 0) to
            (1.05 * max(x), 1.05 * max(mean)).

    Plotly credentials are read from the ``PLOTLY_USERNAME`` and
    ``PLOTLY_API_KEY`` environment variables when both are set. Otherwise no
    ``py.sign_in`` call is made and Plotly uses whatever credentials it
    already has configured.

    Raises:
        ValueError: if there are no (x, samples) pairs to plot.
    """
    import os

    # Credentials used to be hard-coded here and remain in the source history;
    # treat them as compromised (rotate them) and never commit keys again.
    username = os.environ.get("PLOTLY_USERNAME")
    api_key = os.environ.get("PLOTLY_API_KEY")
    if username and api_key:
        py.sign_in(username=username, api_key=api_key)

    # Sort on x only, so sample collections are never compared on ties.
    pairs = sorted(zip(x1, y1s), key=lambda pair: pair[0])
    if not pairs:
        raise ValueError("shaded_std_plot: need at least one (x, samples) pair")
    xs = [px for px, _ in pairs]
    means = [np.mean(samples) for _, samples in pairs]
    stds = [np.std(samples) for _, samples in pairs]
    upper = [m + s for m, s in zip(means, stds)]
    lower = [m - s for m, s in zip(means, stds)]

    band = Scatter(
        # Closed polygon: upper edge left-to-right, lower edge right-to-left.
        # 'toself' fills exactly that polygon; the former 'tozerox' also
        # joined the end points to x=0, shading extra area left of the data.
        x=xs + xs[::-1],
        y=upper + lower[::-1],
        fill='toself',
        fillcolor='rgba(31, 119, 180,0.2)',
        # Plain dict instead of ``scatter.Line``, whose availability depends
        # on the plotly version (``scatter`` is not imported by this module).
        line=dict(color='rgba(31, 119, 180,0)'),
        showlegend=False,
    )
    mean_points = Scatter(
        x=xs,
        y=means,
        line=dict(color='rgb(31, 119, 180)'),
        mode='markers',
    )
    data = [band, mean_points]
    if plotline:
        data.append(Scatter(
            x=[0, 1.05 * max(xs)],
            y=[0, 1.05 * max(means)],
            mode='lines',
            name='lines',
        ))

    layout = Layout(
        showlegend=False,
        width=600,
        height=500,
        font=dict(size=22, color='black'),
        xaxis=dict(
            title=xlabel,
            gridcolor='rgb(127,127,127)',
            domain=[0.05, 1],
            dtick=20,
            range=[0.9 * min(xs), 1.05 * max(xs)],
            showgrid=False,
            showline=True,
            showticklabels=True,
            tickcolor='rgb(127,127,127)',
            ticks='outside',
            zeroline=False,
        ),
        yaxis=dict(
            title=ylabel,
            gridcolor='rgb(127,127,127)',
            showgrid=False,
            range=[0.9 * min(means), 1.05 * max(upper)],
            showline=True,
            showticklabels=True,
            tickcolor='rgb(127,127,127)',
            ticks='outside',
            zeroline=False,
        ),
    )
    fig = Figure(data=data, layout=layout)
    py.image.save_as(fig, filename=filename)
def shaded_std_plot_double(x1,y1s,x2,y2s,filename,xlabel,ylabel):
    """Save two mean +/- 1 std band plots on shared axes.

    Series 1 (blue, legend "Neural network") comes from ``x1``/``y1s``.
    Series 2 (red, legend "Unbiased learner") comes from ``x2``/``y2s``. For
    each series, points are sorted by x. Markers show ``np.mean`` of each
    sample collection, and a translucent band spans mean -/+ ``np.std``
    (NumPy's default ``ddof=0``).

    Args:
        x1, y1s: x positions and per-x sample collections for series 1.
        x2, y2s: x positions and per-x sample collections for series 2.
        filename: output image path for ``py.image.save_as``.
        xlabel, ylabel: axis titles.

    Plotly credentials are read from the ``PLOTLY_USERNAME`` and
    ``PLOTLY_API_KEY`` environment variables when both are set. Otherwise no
    ``py.sign_in`` call is made and Plotly uses whatever credentials it
    already has configured.

    Raises:
        ValueError: if either series has no (x, samples) pairs.
    """
    import os

    # Credentials used to be hard-coded here and remain in the source history;
    # treat them as compromised (rotate them) and never commit keys again.
    username = os.environ.get("PLOTLY_USERNAME")
    api_key = os.environ.get("PLOTLY_API_KEY")
    if username and api_key:
        py.sign_in(username=username, api_key=api_key)

    def _mean_std_band(xs, sample_sets):
        # Sort on x only; return lists (x, mean, mean + std, mean - std).
        pairs = sorted(zip(xs, sample_sets), key=lambda pair: pair[0])
        if not pairs:
            raise ValueError("shaded_std_plot_double: each series needs at least one (x, samples) pair")
        sorted_x = [px for px, _ in pairs]
        means = [np.mean(samples) for _, samples in pairs]
        stds = [np.std(samples) for _, samples in pairs]
        upper = [m + s for m, s in zip(means, stds)]
        lower = [m - s for m, s in zip(means, stds)]
        return sorted_x, means, upper, lower

    def _band_trace(xs, upper, lower, fillcolor):
        # Closed polygon: upper edge left-to-right, lower edge right-to-left.
        # 'toself' fills exactly that polygon; the former 'tozerox' also
        # joined the end points to x=0, shading extra area left of the data.
        return Scatter(
            x=xs + xs[::-1],
            y=upper + lower[::-1],
            fill='toself',
            fillcolor=fillcolor,
            line=dict(color='transparent'),
            showlegend=False,
        )

    xs1, mean1, upper1, lower1 = _mean_std_band(x1, y1s)
    xs2, mean2, upper2, lower2 = _mean_std_band(x2, y2s)

    data = [
        _band_trace(xs1, upper1, lower1, 'rgba(31, 119, 180,0.2)'),
        Scatter(
            x=xs1,
            y=mean1,
            line=dict(color='rgb(31, 119, 180)'),
            mode='markers',
            name='Neural network',
        ),
        _band_trace(xs2, upper2, lower2, 'rgba(255,0,0,0.2)'),
        Scatter(
            x=xs2,
            y=mean2,
            line=dict(color='rgb(255,0,0)'),
            mode='markers',
            name='Unbiased learner',
        ),
    ]
    all_x = xs1 + xs2
    layout = Layout(
        legend=dict(x=0.75, y=0.08),
        font=dict(size=18, color='black'),
        xaxis=dict(
            title=xlabel,
            gridcolor='rgb(127,127,127)',
            dtick=20,
            range=[0.9 * min(all_x), max(all_x) * 1.1],
            showgrid=False,
            showline=True,
            showticklabels=True,
            tickcolor='rgb(127,127,127)',
            ticks='outside',
            zeroline=False,
        ),
        yaxis=dict(
            title=ylabel,
            gridcolor='rgb(127,127,127)',
            showgrid=False,
            showline=True,
            showticklabels=True,
            tickcolor='rgb(127,127,127)',
            ticks='outside',
            zeroline=False,
        ),
    )
    fig = Figure(data=data, layout=layout)
    py.image.save_as(fig, filename=filename)
def shaded_std_plot_double_scatter(x1,y1s,x2,y2s,x3,y3,filename,xlabel,ylabel):
    """Save two mean +/- 1 std band plots plus a third plain scatter series.

    Series 1 (blue, legend "Neural network") comes from ``x1``/``y1s``.
    Series 2 (red, legend "Unbiased learner") comes from ``x2``/``y2s``. Each
    is drawn as small 'x' markers at ``np.mean`` of every sample collection,
    with a translucent band spanning mean -/+ ``np.std`` (NumPy's default
    ``ddof=0``). ``x3``/``y3`` are plotted as-is as black markers labelled
    "Predicted bound". Both axes start at 0.

    Args:
        x1, y1s: x positions and per-x sample collections for series 1.
        x2, y2s: x positions and per-x sample collections for series 2.
        x3, y3: coordinates of the third series, plotted unchanged.
        filename: output image path for ``py.image.save_as``.
        xlabel, ylabel: axis titles.

    Plotly credentials are read from the ``PLOTLY_USERNAME`` and
    ``PLOTLY_API_KEY`` environment variables when both are set. Otherwise no
    ``py.sign_in`` call is made and Plotly uses whatever credentials it
    already has configured.

    Raises:
        ValueError: if series 1 or series 2 has no (x, samples) pairs.
    """
    import os

    # Credentials used to be hard-coded here and remain in the source history;
    # treat them as compromised (rotate them) and never commit keys again.
    username = os.environ.get("PLOTLY_USERNAME")
    api_key = os.environ.get("PLOTLY_API_KEY")
    if username and api_key:
        py.sign_in(username=username, api_key=api_key)

    def _mean_std_band(xs, sample_sets):
        # Sort on x only; return lists (x, mean, mean + std, mean - std).
        pairs = sorted(zip(xs, sample_sets), key=lambda pair: pair[0])
        if not pairs:
            raise ValueError("shaded_std_plot_double_scatter: each series needs at least one (x, samples) pair")
        sorted_x = [px for px, _ in pairs]
        means = [np.mean(samples) for _, samples in pairs]
        stds = [np.std(samples) for _, samples in pairs]
        upper = [m + s for m, s in zip(means, stds)]
        lower = [m - s for m, s in zip(means, stds)]
        return sorted_x, means, upper, lower

    def _band_trace(xs, upper, lower, fillcolor):
        # Closed polygon: upper edge left-to-right, lower edge right-to-left.
        # 'toself' fills exactly that polygon; the former 'tozerox' also
        # joined the end points to x=0, which is visible here because the
        # x axis starts at 0.
        return Scatter(
            x=xs + xs[::-1],
            y=upper + lower[::-1],
            fill='toself',
            fillcolor=fillcolor,
            line=dict(color='transparent'),
            showlegend=False,
        )

    xs1, mean1, upper1, lower1 = _mean_std_band(x1, y1s)
    xs2, mean2, upper2, lower2 = _mean_std_band(x2, y2s)

    data = [
        _band_trace(xs1, upper1, lower1, 'rgba(31, 119, 180,0.2)'),
        Scatter(
            x=xs1,
            y=mean1,
            line=dict(color='rgb(31, 119, 180)'),
            mode='markers',
            marker=dict(size=4, symbol='x'),
            name='Neural network',
        ),
        _band_trace(xs2, upper2, lower2, 'rgba(255,0,0,0.2)'),
        Scatter(
            x=xs2,
            y=mean2,
            line=dict(color='rgb(255,0,0)'),
            mode='markers',
            marker=dict(size=4, symbol='x'),
            name='Unbiased learner',
        ),
        Scatter(
            x=x3,
            y=y3,
            line=dict(color='rgb(0,0,0)'),
            mode='markers',
            name='Predicted bound',
        ),
    ]
    layout = Layout(
        legend=dict(x=0.01, y=1),
        font=dict(size=18, color='black'),
        xaxis=dict(
            title=xlabel,
            gridcolor='rgb(127,127,127)',
            range=[0, max(xs1 + xs2) * 1.1],
            showgrid=False,
            showline=True,
            showticklabels=True,
            tickcolor='rgb(127,127,127)',
            ticks='outside',
            zeroline=False,
        ),
        yaxis=dict(
            title=ylabel,
            gridcolor='rgb(127,127,127)',
            range=[0, max(mean1 + mean2) * 1.3],
            showgrid=False,
            showline=True,
            showticklabels=True,
            tickcolor='rgb(127,127,127)',
            ticks='outside',
            zeroline=False,
        ),
    )
    fig = Figure(data=data, layout=layout)
    py.image.save_as(fig, filename=filename)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment