Created
February 14, 2020 06:08
-
-
Save smutch/8b0146b47818c0d9cfe21002d9b78bd8 to your computer and use it in GitHub Desktop.
py: blackbox testing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Testing of the blackbox package: how well does it work, and do I understand how to use it?
""" | |
import sys | |
sys.path.insert(0, "../3rd_party/blackbox") # noqa: E402 | |
import blackbox as bb | |
from scipy import stats | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from joblib import Memory | |
# from snoop import pp | |
memory = Memory('cache') | |
def target(theta):
    """Negative log-density of a standard normal at the first parameter.

    Only ``theta[0]`` is read.  The sign is flipped so that minimising this
    function corresponds to maximising the normal density.
    """
    first_param = theta[0]
    log_density = stats.norm.logpdf(first_param)
    return -log_density
def target_nd(theta, mean=[1.0, 0.0], cov=np.eye(2)*0.1):
    """Negative log-density of a multivariate normal evaluated at ``theta``.

    ``mean`` and ``cov`` are handed straight to
    ``scipy.stats.multivariate_normal``; by default this is a 2-D Gaussian
    centred on (1, 0) with covariance 0.1 * identity.  The default objects
    are only read here, never modified.
    """
    gaussian = stats.multivariate_normal(mean, cov)
    log_density = gaussian.logpdf(theta)
    return -log_density
def _search_min(target=target, domain=[[-10.0, 10.0]], budget=40, batch=4, resfile="output.csv"):
    """Run ``bb.search_min`` and return the contents of its results file.

    The arguments are forwarded unchanged to ``bb.search_min``:

    target  -- function to minimise
    domain  -- [lower, upper] range for each parameter
    budget  -- total number of function calls available
    batch   -- number of calls that will be evaluated in parallel
    resfile -- text file where the results will be saved

    Returns a DataFrame read back from ``resfile`` whose column names have
    had leading and trailing spaces removed.
    """
    bb.search_min(target, domain=domain, budget=budget, batch=batch, resfile=resfile)

    results = pd.read_csv(resfile)
    # Trim any space padding around the header names so columns can be
    # addressed by their bare names.
    results.columns = [name.strip(' ') for name in results.columns]
    return results
# Wrap the search in the joblib cache so that re-running the script with
# unchanged arguments does not repeat the optimisation.
search_min = memory.cache(_search_min)

# 1D
resfile = "output.csv"
domain = [[-10.0, 10.0]]
res = search_min(resfile=resfile, domain=domain)

fig, ax = plt.subplots(1, 1, tight_layout=True)
theta = np.linspace(domain[0][0], domain[0][1], 100)
# target() reads theta[0], so add a leading axis to evaluate the whole grid
# in a single call.
ax.plot(theta, target(theta[np.newaxis, :]), ls="--", color='0.8', lw=2, label='target', zorder=0)
# Hollow markers: c='' is not a valid colour spec and is rejected by current
# matplotlib; 'none' is the supported way to request no face colour.
res.plot.scatter("par_1", "f_value", ax=ax, c='none', edgecolor='dodgerblue', s=20, label='trials')
# NOTE(review): treats the first row of the results file as the best trial --
# confirm that blackbox writes its results sorted by function value.
res.loc[[0]].plot.scatter("par_1", "f_value", ax=ax, label='optimum', s=30)
ax.legend(loc='upper center')
plt.tight_layout()
fig.savefig("results.pdf")
# 2D
resfile = "output.csv"
domain = [[-10.0, 10.0], [-10.0, 10.0]]
res = search_min(target=target_nd, resfile=resfile, domain=domain)

fig, ax = plt.subplots(1, 1, tight_layout=True)
# theta[0] varies along axis 0 (first parameter), theta[1] along axis 1
# (second parameter); np.dstack(theta)[i, j] is the point (theta[0][i, j], theta[1][i, j]).
theta = np.mgrid[domain[0][0]:domain[0][1]:100j, domain[1][0]:domain[1][1]:100j]
# ax.imshow(target_nd(np.dstack(theta)).T, origin='lower', extent=[v for vv in domain for v in vv])
# Pass the coordinate grids in (par_1, par_2) order with the values
# untransposed.  The previous form swapped the grids and transposed the
# values, which only lined up because both parameter ranges are identical.
# `origin` is omitted because it has no effect when X and Y are given.
ax.contourf(theta[0], theta[1], target_nd(np.dstack(theta)), 20)
# Marker at the default mean of target_nd.
ax.scatter([1.0], [0.0], marker='s', color='w', s=70, label='truth')
# Hollow markers: c='' is not a valid colour spec and is rejected by current
# matplotlib; 'none' is the supported way to request no face colour.
res.plot.scatter("par_1", "par_2", ax=ax, c='none', edgecolor='r', s=20, label='trials')
# NOTE(review): treats the first row of the results file as the best trial --
# confirm that blackbox writes its results sorted by function value.
res.loc[[0]].plot.scatter("par_1", "par_2", ax=ax, label='optimum', c='firebrick', s=30)
ax.legend(loc='upper center')
plt.tight_layout()
fig.savefig("results_2d.pdf")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment