Skip to content

Instantly share code, notes, and snippets.

@smutch
Created February 14, 2020 06:08
Show Gist options
  • Save smutch/8b0146b47818c0d9cfe21002d9b78bd8 to your computer and use it in GitHub Desktop.
Save smutch/8b0146b47818c0d9cfe21002d9b78bd8 to your computer and use it in GitHub Desktop.
py: blackbox testing
"""
Exploratory tests of the blackbox package: checking how well it works and making sure I understand how to use it.
"""
import sys
sys.path.insert(0, "../3rd_party/blackbox") # noqa: E402
import blackbox as bb
from scipy import stats
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from joblib import Memory
# from snoop import pp
# On-disk joblib cache stored in the "cache" directory; it is used below to
# memoise the blackbox searches between runs of this script.
memory = Memory(location='cache')
def target(theta):
    """Return the negative log-density of a standard normal at ``theta[0]``.

    Only the first element (or first row) of *theta* is used. This means the
    function works both with a length-1 parameter vector and with a
    ``(1, n)`` array that evaluates ``n`` points at once.
    """
    x = theta[0]
    log_density = stats.norm.logpdf(x)
    return -log_density
def target_nd(theta, mean=(1.0, 0.0), cov=None):
    """Return the negative log-density of a multivariate normal at *theta*.

    Parameters
    ----------
    theta : array_like
        Point(s) at which to evaluate. The last axis indexes the parameters,
        so a grid of shape ``(..., d)`` gives values of shape ``(...)``.
    mean : sequence of float, optional
        Mean of the distribution. Defaults to ``(1.0, 0.0)``.
    cov : array_like, optional
        Covariance matrix. ``None`` (the default) means ``np.eye(2) * 0.1``.

    Returns
    -------
    float or ndarray
        ``-log p(theta)``.
    """
    # Build the default covariance on each call instead of sharing a single
    # ndarray created at definition time (mutable default argument).
    if cov is None:
        cov = np.eye(2) * 0.1
    return -stats.multivariate_normal(mean, cov).logpdf(theta)
def _search_min(target=target, domain=None, budget=40, batch=4, resfile="output.csv"):
    """Run ``blackbox.search_min`` on *target* and load the saved results.

    Parameters
    ----------
    target : callable
        Objective function to minimise; blackbox calls it.
    domain : list of [low, high] pairs, optional
        Search range of each parameter. ``None`` (the default) means
        ``[[-10.0, 10.0]]``, i.e. a single parameter.
    budget : int
        Total number of function calls available.
    batch : int
        Number of calls evaluated in parallel.
    resfile : str
        Text file where blackbox saves its results; it is read back here.

    Returns
    -------
    pandas.DataFrame
        The contents of *resfile*, with surrounding spaces stripped from the
        column names so they can be addressed as e.g. ``"par_1"``/``"f_value"``.
    """
    # Build a fresh list on each call instead of sharing a mutable default.
    if domain is None:
        domain = [[-10.0, 10.0]]
    bb.search_min(
        target,           # given function
        domain=domain,    # ranges of each parameter
        budget=budget,    # total number of function calls available
        batch=batch,      # number of calls that will be evaluated in parallel
        resfile=resfile,  # text file where results will be saved
    )
    res = pd.read_csv(resfile)
    # The column names come back padded with spaces; strip them.
    res.columns = res.columns.str.strip(' ')
    return res
# Disk-memoised wrapper around _search_min. Repeated calls with the same
# arguments are served from the joblib cache instead of re-running the search.
search_min = memory.cache(func=_search_min)
# 1D: minimise the negative log-density of a standard normal and compare the
# trial points with the true curve.
resfile = "output.csv"
domain = [[-10.0, 10.0]]
res = search_min(resfile=resfile, domain=domain)
# Row with the smallest objective value. This does not rely on the results
# file being written in sorted order.
best = res.loc[[res["f_value"].idxmin()]]

fig, ax = plt.subplots(1, 1, tight_layout=True)
theta = np.linspace(domain[0][0], domain[0][1], 100)
# target() reads theta[0], so add a leading axis to evaluate the whole grid at once.
ax.plot(theta, target(theta[np.newaxis, :]), ls="--", color='0.8', lw=2, label='target', zorder=0)
# c='none' gives hollow markers; an empty string is not a valid Matplotlib colour.
res.plot.scatter("par_1", "f_value", ax=ax, c='none', edgecolor='dodgerblue', s=20, label='trials')
best.plot.scatter("par_1", "f_value", ax=ax, label='optimum', s=30)
ax.legend(loc='upper center')
# The figure was created with tight_layout=True, so no separate tight_layout() call is needed.
fig.savefig("results.pdf")
# 2D: minimise the negative log-density of a bivariate normal and overlay the
# trial points on filled contours of the target.
# Use a separate results file so the 1D run's output.csv is not overwritten.
resfile = "output_2d.csv"
domain = [[-10.0, 10.0], [-10.0, 10.0]]
res = search_min(target=target_nd, resfile=resfile, domain=domain)
# Row with the smallest objective value (no assumption about file ordering).
best = res.loc[[res["f_value"].idxmin()]]

fig, ax = plt.subplots(1, 1, tight_layout=True)
# theta[0] holds par_1 and theta[1] holds par_2 on a 100x100 grid. np.dstack puts the
# parameters on the last axis, so target_nd's output is indexed like theta[0]/theta[1].
# Passing X=theta[0], Y=theta[1], Z untransposed is correct even when the two
# domain ranges differ.
theta = np.mgrid[domain[0][0]:domain[0][1]:100j, domain[1][0]:domain[1][1]:100j]
ax.contourf(theta[0], theta[1], target_nd(np.dstack(theta)), 20)
# True minimum: the default mean of target_nd.
ax.scatter([1.0], [0.0], marker='s', color='w', s=70, label='truth')
# c='none' gives hollow markers; an empty string is not a valid Matplotlib colour.
res.plot.scatter("par_1", "par_2", ax=ax, c='none', edgecolor='r', s=20, label='trials')
best.plot.scatter("par_1", "par_2", ax=ax, label='optimum', c='firebrick', s=30)
ax.legend(loc='upper center')
fig.savefig("results_2d.pdf")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment