@sneakers-the-rat
Created March 12, 2021 00:01
grid_search_ssm.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Parameter Search\n",
"\n",
"We can automatically find the optimal parameters for a model by running a shitload of models and comparing them. To do that all we need is a method for fitting the model and a method for scoring it, or evaluating how 'good' the model is. \n",
"\n",
"There's a general Python package commonly used for HMMs that follows the sklearn API, https://github.com/hmmlearn/hmmlearn, but for compatibility we'll use the ssm package https://github.com/lindermanlab/ssm\n",
"\n",
"To make this fast, make sure OpenMP is available so ssm's compiled routines can run in parallel:\n",
"\n",
"On macOS (also install llvm to get `-fopenmp`):\n",
"`brew install libomp llvm`\n",
"\n",
"On Ubuntu:\n",
"`sudo apt update && sudo apt install -y libomp-dev`\n",
"\n",
"and then install the ssm module with OpenMP enabled, from the root of a clone of the ssm repository:\n",
"`USE_OPENMP=True pip install -e .`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading Data\n",
"\n",
"Copied and pasted from https://github.com/wehr-lab/Prey-Capture-SSM/blob/master/ssm_preycap_posterior.py"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# imports for loading the training data and running the grid search\n",
"\n",
"# builtins\n",
"import itertools\n",
"import time\n",
"import os\n",
"import sys\n",
"import pprint\n",
"import multiprocessing as mp\n",
"\n",
"# installed\n",
"import tqdm\n",
"import numpy as np\n",
"import ssm\n",
"from scipy import io, stats, signal\n",
"import pandas as pd\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# load data using scipy.io.loadmat\n",
"mat = io.loadmat('training_data.mat')\n",
"X = mat['X']"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data was (195663, 14) and now it's (32611, 14)\n"
]
}
],
"source": [
"big_shape = X.shape\n",
"# np.save('training_data.npy', X)\n",
"\n",
"# downsampling by factor of 6 for ~33Hz\n",
"# actually don't, since it's already decimated; kept just to\n",
"# show how to do it in Python\n",
"#X = signal.decimate(X, 6, axis=0)\n",
"\n",
"print(f'Data was {big_shape} and now it\\'s {X.shape}')"
]
},
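{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to see what `signal.decimate` does to the shape, sketched on made-up toy data rather than `X`. It low-pass filters and then keeps every q-th sample, so n rows become ceil(n / q) rows. For the full data that is ceil(195663 / 6) = 32611, the row count printed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"from scipy import signal\n",
"\n",
"# toy stand-in for X: 601 samples of 14 random features (made-up data)\n",
"rng = np.random.default_rng(0)\n",
"X_toy = rng.standard_normal((601, 14))\n",
"\n",
"q = 6  # decimation factor, as in the commented-out line above\n",
"X_dec = signal.decimate(X_toy, q, axis=0)\n",
"\n",
"# decimate keeps every q-th sample after filtering, so rows -> ceil(rows / q)\n",
"print(X_toy.shape, '->', X_dec.shape)  # (601, 14) -> (101, 14)\n",
"assert X_dec.shape[0] == math.ceil(X_toy.shape[0] / q)"
]
},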
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Specifying Parameters\n",
"\n",
"In this case we'll do an exhaustive grid search, since it would take a bit of work to get the ssm objects to work with the sklearn API and its randomized search (see the sklearn [documentation](https://scikit-learn.org/stable/modules/grid_search.html#randomized-parameter-optimization) for reasons why that would be a good idea)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"num_states = list(range(5,20,2)) # K - number of discrete states\n",
"observations = ['autoregressive']\n",
"transitions = ['standard', 'sticky', 'recurrent']\n",
"\n",
"# since ssm.HMM takes observation options as a nested dict (observation_kwargs),\n",
"# the candidate values have to be a list of dictionaries\n",
"observation_kwargs = [{'lags':i} for i in range(1,11,5)]\n",
"\n",
"# static parameters\n",
"obs_dim = [X.shape[1]] # dimensionality of observation\n",
"\n",
"# --------------\n",
"# combine in a dictionary!\n",
"param_search = {\n",
" 'K': num_states,\n",
" 'D': obs_dim,\n",
" 'observations': observations,\n",
" 'transitions': transitions,\n",
" 'observation_kwargs': observation_kwargs\n",
"}\n",
"\n",
"# make a lil generator for iterating over params\n",
"def param_product(**kwargs):\n",
"    \"\"\"Yield one dict per combination in the cartesian product of the kwargs' values.\n",
"    \n",
"    https://stackoverflow.com/a/5228294/13113166\n",
"    \"\"\"\n",
"    keys = kwargs.keys()\n",
"    vals = kwargs.values()\n",
"    for instance in itertools.product(*vals):\n",
"        yield dict(zip(keys, instance))\n"
]
},
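{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the grid size: the number of models is the product of the number of values for each parameter, 8 (K) x 1 (D) x 1 (observations) x 3 (transitions) x 2 (lags) = 48. The sketch below rebuilds the same value lists, with `D` hard-coded to 14 as a stand-in for `X.shape[1]`, and enumerates them with `itertools.product` the way `param_product` does."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"# same value lists as above, with D hard-coded to 14 instead of X.shape[1]\n",
"grid = {\n",
"    'K': list(range(5, 20, 2)),\n",
"    'D': [14],\n",
"    'observations': ['autoregressive'],\n",
"    'transitions': ['standard', 'sticky', 'recurrent'],\n",
"    'observation_kwargs': [{'lags': i} for i in range(1, 11, 5)],\n",
"}\n",
"\n",
"# the grid size is the product of the number of values for each parameter\n",
"n_models = 1\n",
"for values in grid.values():\n",
"    n_models *= len(values)\n",
"\n",
"# itertools.product varies the last key fastest, like param_product\n",
"combos = [dict(zip(grid.keys(), vals)) for vals in itertools.product(*grid.values())]\n",
"print(n_models, len(combos))  # 48 48"
]
},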
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Will fit 48 models!\n",
"\n",
"[{'D': 14,\n",
" 'K': 5,\n",
" 'observation_kwargs': {'lags': 1},\n",
" 'observations': 'autoregressive',\n",
" 'transitions': 'standard'},\n",
" {'D': 14,\n",
" 'K': 5,\n",
" 'observation_kwargs': {'lags': 6},\n",
" 'observations': 'autoregressive',\n",
" 'transitions': 'standard'},\n",
" {'D': 14,\n",
" 'K': 5,\n",
" 'observation_kwargs': {'lags': 1},\n",
" 'observations': 'autoregressive',\n",
" 'transitions': 'sticky'}]\n"
]
}
],
"source": [
"permutations = list(param_product(**param_search))\n",
"\n",
"print(f'Will fit {len(permutations)} models!\\n')\n",
"\n",
"# print the first 3 as an example\n",
"pprint.pprint(permutations[0:3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Parallel Training & X-Val\n",
"\n",
"We'll use Python's multiprocessing module to run models in parallel, since it proved too tricky to adapt the ssm HMM model to the sklearn API. Each worker in the pool fits one parameter set and cross-validates it on held-out data with `train_and_xval`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def train_and_xval(input:tuple, n_repeats=3, hold_fraction=0.2):\n",
"    \"\"\"\n",
"    Fit an ssm HMM and score it on a held-out, contiguous block of data.\n",
"    \n",
"    input is a tuple of (data, params), since Pool.imap_unordered only passes one argument.\n",
"    Returns (log_likelihoods, params), where log_likelihoods holds one\n",
"    per-timestep test log likelihood for each repeat.\n",
"    \"\"\"\n",
"    data, params = input\n",
"    \n",
"    n_holdout = int(np.floor(hold_fraction*data.shape[0]))\n",
"    \n",
"    # print a blank line to get pbars to show up in notebook\n",
"    print('', end='\\r')\n",
"    \n",
"    log_likelihoods = []\n",
"    for run in range(n_repeats):\n",
"        # split out a contiguous region of data for x-val:\n",
"        # pick a random start between 0 and the end minus n_holdout\n",
"        censor_start = np.random.randint(0, data.shape[0]-n_holdout-1)\n",
"        data_test = data[censor_start:censor_start+n_holdout, :]\n",
"        data_train = np.delete(data, slice(censor_start, censor_start+n_holdout), axis=0)\n",
"        \n",
"        # fit a fresh model on every repeat, so parameters fit on an earlier\n",
"        # split (whose training data may include this split's test block) don't carry over\n",
"        model = ssm.HMM(**params)\n",
"        model.fit(data_train)\n",
"        \n",
"        # store the log likelihood of the test set, normalized per timestep\n",
"        log_likelihoods.append(model.log_likelihood(data_test)/data_test.shape[0])\n",
"    \n",
"    return log_likelihoods, params"
]
},
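{
"cell_type": "markdown",
"metadata": {},
"source": [
"The held-out split can be checked on made-up toy data (a 20 x 2 array standing in for `X`). Note that `np.delete` joins the rows before and after the held-out block, so an autoregressive model sees one artificial jump at the join. If we read ssm's API correctly, `fit` also accepts a list of arrays, so passing the two training segments separately is one possible alternative; the sketch builds both versions and checks that train and test split the data cleanly. It draws the start with `default_rng().integers` and an exclusive upper bound of `n - n_holdout + 1`, so the held-out block can reach the last row."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# toy stand-in for X: 20 time steps x 2 features (made-up data)\n",
"data = np.arange(40).reshape(20, 2)\n",
"hold_fraction = 0.2\n",
"n_holdout = int(np.floor(hold_fraction * data.shape[0]))  # 4 rows\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# integers() excludes the upper bound, so the block always fits in the data\n",
"censor_start = int(rng.integers(0, data.shape[0] - n_holdout + 1))\n",
"held_out = slice(censor_start, censor_start + n_holdout)\n",
"\n",
"data_test = data[held_out]\n",
"data_train = np.delete(data, held_out, axis=0)  # rows before + after, joined\n",
"\n",
"# alternative: keep the two training segments separate (dropping empty ones)\n",
"segments = [seg for seg in (data[:censor_start], data[censor_start + n_holdout:]) if len(seg) > 0]\n",
"\n",
"print(censor_start, data_train.shape, data_test.shape, [s.shape for s in segments])"
]
},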
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# time 2 train!!!\n",
"\n",
"use multiprocessing and train a bunch of models!!!\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def grid_search(data:np.ndarray, param_search:list, n_processes:int=12):\n",
"    \"\"\"\n",
"    Given data and a list of parameter sets, do a grid search with the SSM package.\n",
"    \n",
"    param_search is a list of dictionaries (e.g. from param_product), where each\n",
"    dictionary is one set of ssm.HMM parameters, so every combination\n",
"    in the cartesian product gets its own model run.\n",
"    \n",
"    Returns a pandas DataFrame with one row per parameter set.\n",
"    \"\"\"\n",
"    \n",
"    # pair the same data with every parameter set\n",
"    combined_iter = zip(itertools.cycle([data]), param_search)\n",
"    \n",
"    # store results\n",
"    model_runs = []\n",
"    \n",
"    # create multiprocessing pool\n",
"    with mp.Pool(n_processes) as pool:\n",
"        \n",
"        # start workers; results come back in whatever order they finish\n",
"        results = pool.imap_unordered(train_and_xval, combined_iter)\n",
"        \n",
"        # progress bar over completed models\n",
"        with tqdm.tqdm(total=len(param_search)) as pbar:\n",
"            for r in results:\n",
"                # unpack results a bit for easier dataframe creation later\n",
"                log_likelihoods, params = r\n",
"                params['log_likelihoods'] = log_likelihoods\n",
"                model_runs.append(params)\n",
"                pbar.update()\n",
"    \n",
"    # combine into pandas dataframe\n",
"    return pd.DataFrame(model_runs)\n"
]
},
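{
"cell_type": "markdown",
"metadata": {},
"source": [
"When debugging, it can help to run the same jobs without the pool. The builtin `map` accepts the same one-argument function and the same iterator of `(data, params)` tuples as `pool.imap_unordered`, but runs them in order in the current process, so errors surface as ordinary tracebacks. This also helps on macOS, where Python 3.8+ starts pool workers with 'spawn' by default and the workers can't import functions defined inside the notebook; moving `train_and_xval` into a .py module is the usual fix there. The sketch uses a hypothetical `fake_train` in place of `train_and_xval` so it runs without ssm, and `itertools.repeat(data)`, which pairs the data with every parameter set just like `itertools.cycle([data])`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"\n",
"def fake_train(args):\n",
"    # hypothetical stand-in for train_and_xval: (data, params) -> (scores, params)\n",
"    data, params = args\n",
"    return [sum(data) * params['K']], params\n",
"\n",
"data = [1, 2, 3]\n",
"param_list = [{'K': 5}, {'K': 7}]\n",
"\n",
"# same pairing as in grid_search, but run serially with the builtin map\n",
"combined_iter = zip(itertools.repeat(data), param_list)\n",
"serial_results = list(map(fake_train, combined_iter))\n",
"print(serial_results)  # [([30], {'K': 5}), ([42], {'K': 7})]"
]
},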
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"100%|██████████| 12/12 [11:26<00:00, 54.08s/it]"
]
}
],
"source": [
"# only fit the first 12 of the 48 parameter sets, on the first 10000 samples\n",
"params = param_product(**param_search)\n",
"params = list(params)[0:12]\n",
"\n",
"results = grid_search(X[0:10000,:], params, n_processes=12)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>K</th>\n",
" <th>D</th>\n",
" <th>observations</th>\n",
" <th>transitions</th>\n",
" <th>observation_kwargs</th>\n",
" <th>log_likelihoods</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>sticky</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.63600490820458, 73.66281742297728, 71.8680...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>standard</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.64167739121682, 74.97694287706454, 72.0947...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>sticky</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.67920859553763, 70.55818373405485, 74.6496...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>standard</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.71081642856838, 70.63887450049349, 74.4475...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>recurrent</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.63214419000296, 73.6926339542364, 72.05606...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>recurrent</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.76921456088444, 70.81079060428786, 74.6286...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>standard</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[76.4593292441041, 77.44156131923047, 69.78935...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>sticky</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[77.16071708938951, 74.50652839566231, 67.1754...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>recurrent</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[77.14710032562544, 76.64457651074046, 70.9074...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>sticky</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[55.10422483160579, -4.791925793428352, 49.359...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>standard</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[57.016592852844255, -6.453637578048345, 48.95...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>recurrent</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[57.57047274312927, 22.043312475814187, 47.546...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" K D observations transitions observation_kwargs \\\n",
"0 5 14 autoregressive sticky {'lags': 1} \n",
"1 5 14 autoregressive standard {'lags': 1} \n",
"2 7 14 autoregressive sticky {'lags': 1} \n",
"3 7 14 autoregressive standard {'lags': 1} \n",
"4 5 14 autoregressive recurrent {'lags': 1} \n",
"5 7 14 autoregressive recurrent {'lags': 1} \n",
"6 5 14 autoregressive standard {'lags': 6} \n",
"7 5 14 autoregressive sticky {'lags': 6} \n",
"8 5 14 autoregressive recurrent {'lags': 6} \n",
"9 7 14 autoregressive sticky {'lags': 6} \n",
"10 7 14 autoregressive standard {'lags': 6} \n",
"11 7 14 autoregressive recurrent {'lags': 6} \n",
"\n",
" log_likelihoods \n",
"0 [73.63600490820458, 73.66281742297728, 71.8680... \n",
"1 [73.64167739121682, 74.97694287706454, 72.0947... \n",
"2 [73.67920859553763, 70.55818373405485, 74.6496... \n",
"3 [73.71081642856838, 70.63887450049349, 74.4475... \n",
"4 [73.63214419000296, 73.6926339542364, 72.05606... \n",
"5 [73.76921456088444, 70.81079060428786, 74.6286... \n",
"6 [76.4593292441041, 77.44156131923047, 69.78935... \n",
"7 [77.16071708938951, 74.50652839566231, 67.1754... \n",
"8 [77.14710032562544, 76.64457651074046, 70.9074... \n",
"9 [55.10422483160579, -4.791925793428352, 49.359... \n",
"10 [57.016592852844255, -6.453637578048345, 48.95... \n",
"11 [57.57047274312927, 22.043312475814187, 47.546... "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"results['mean_log_likelihood'] = [np.mean(vals) for vals in results.log_likelihoods.values]\n",
"results['sd_log_likelihood'] = [np.std(vals) for vals in results.log_likelihoods.values]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>K</th>\n",
" <th>D</th>\n",
" <th>observations</th>\n",
" <th>transitions</th>\n",
" <th>observation_kwargs</th>\n",
" <th>log_likelihoods</th>\n",
" <th>mean_log_likelihood</th>\n",
" <th>sd_log_likelihood</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>standard</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[57.016592852844255, -6.453637578048345, 48.95...</td>\n",
" <td>33.172303</td>\n",
" <td>28.212443</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>sticky</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[55.10422483160579, -4.791925793428352, 49.359...</td>\n",
" <td>33.223894</td>\n",
" <td>26.983362</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>recurrent</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[57.57047274312927, 22.043312475814187, 47.546...</td>\n",
" <td>42.386693</td>\n",
" <td>14.955732</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>standard</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.71081642856838, 70.63887450049349, 74.4475...</td>\n",
" <td>72.932398</td>\n",
" <td>1.649417</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>sticky</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[77.16071708938951, 74.50652839566231, 67.1754...</td>\n",
" <td>72.947558</td>\n",
" <td>4.222897</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>sticky</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.67920859553763, 70.55818373405485, 74.6496...</td>\n",
" <td>72.962356</td>\n",
" <td>1.745563</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>sticky</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.63600490820458, 73.66281742297728, 71.8680...</td>\n",
" <td>73.055620</td>\n",
" <td>0.839819</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>7</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>recurrent</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.76921456088444, 70.81079060428786, 74.6286...</td>\n",
" <td>73.069553</td>\n",
" <td>1.635270</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>recurrent</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.63214419000296, 73.6926339542364, 72.05606...</td>\n",
" <td>73.126948</td>\n",
" <td>0.757630</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>standard</td>\n",
" <td>{'lags': 1}</td>\n",
" <td>[73.64167739121682, 74.97694287706454, 72.0947...</td>\n",
" <td>73.571137</td>\n",
" <td>1.177691</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>standard</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[76.4593292441041, 77.44156131923047, 69.78935...</td>\n",
" <td>74.563416</td>\n",
" <td>3.399502</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>5</td>\n",
" <td>14</td>\n",
" <td>autoregressive</td>\n",
" <td>recurrent</td>\n",
" <td>{'lags': 6}</td>\n",
" <td>[77.14710032562544, 76.64457651074046, 70.9074...</td>\n",
" <td>74.899693</td>\n",
" <td>2.830421</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" K D observations transitions observation_kwargs \\\n",
"10 7 14 autoregressive standard {'lags': 6} \n",
"9 7 14 autoregressive sticky {'lags': 6} \n",
"11 7 14 autoregressive recurrent {'lags': 6} \n",
"3 7 14 autoregressive standard {'lags': 1} \n",
"7 5 14 autoregressive sticky {'lags': 6} \n",
"2 7 14 autoregressive sticky {'lags': 1} \n",
"0 5 14 autoregressive sticky {'lags': 1} \n",
"5 7 14 autoregressive recurrent {'lags': 1} \n",
"4 5 14 autoregressive recurrent {'lags': 1} \n",
"1 5 14 autoregressive standard {'lags': 1} \n",
"6 5 14 autoregressive standard {'lags': 6} \n",
"8 5 14 autoregressive recurrent {'lags': 6} \n",
"\n",
" log_likelihoods mean_log_likelihood \\\n",
"10 [57.016592852844255, -6.453637578048345, 48.95... 33.172303 \n",
"9 [55.10422483160579, -4.791925793428352, 49.359... 33.223894 \n",
"11 [57.57047274312927, 22.043312475814187, 47.546... 42.386693 \n",
"3 [73.71081642856838, 70.63887450049349, 74.4475... 72.932398 \n",
"7 [77.16071708938951, 74.50652839566231, 67.1754... 72.947558 \n",
"2 [73.67920859553763, 70.55818373405485, 74.6496... 72.962356 \n",
"0 [73.63600490820458, 73.66281742297728, 71.8680... 73.055620 \n",
"5 [73.76921456088444, 70.81079060428786, 74.6286... 73.069553 \n",
"4 [73.63214419000296, 73.6926339542364, 72.05606... 73.126948 \n",
"1 [73.64167739121682, 74.97694287706454, 72.0947... 73.571137 \n",
"6 [76.4593292441041, 77.44156131923047, 69.78935... 74.563416 \n",
"8 [77.14710032562544, 76.64457651074046, 70.9074... 74.899693 \n",
"\n",
" sd_log_likelihood \n",
"10 28.212443 \n",
"9 26.983362 \n",
"11 14.955732 \n",
"3 1.649417 \n",
"7 4.222897 \n",
"2 1.745563 \n",
"0 0.839819 \n",
"5 1.635270 \n",
"4 0.757630 \n",
"1 1.177691 \n",
"6 3.399502 \n",
"8 2.830421 "
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results.sort_values(by=\"mean_log_likelihood\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:xlabel='K', ylabel='mean_log_likelihood'>"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEGCAYAAACNaZVuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUyUlEQVR4nO3dfZAcdZ3H8fc3CYkxQQhJjEiIQaGw8AyRWylQRJRTEY/EE8WnEkQoyqrzqSwVlCqf/vAh5d35WNblsDywfECJCFqKIoh6ougiIaiIBAQJQggxQMJDWNjv/TG9sFk22e7NdM/u9vtVNc50z/TMJ0P7md7f9HRHZiJJapdpvQ4gSWqe5S9JLWT5S1ILWf6S1EKWvyS10IxeByhrwYIFuXTp0l7HkKRJ5eqrr747MxeOnD9pyn/p0qX09/f3OoYkTSoRceto8x32kaQWsvwlqYUsf0lqIctfklrI8pekFpry5b9523auve0eNm/b3usokjRhTJpdPcfjorW3c+aadewxbRoDg4OsOnEZK5bv1+tYktRzU3bLf/O27Zy5Zh0PDQyydfsjPDQwyAfWrPMvAEliCpf/hi0PVpovSW0yZct/zszpPDQwuMO8hwYGmTNzeo8SSdLEMWXL//6HH2XW9Nhh3qzpwf0PP9qjRJI0cUzZ8l88bzYxbcfyj2nB4nmze5RIkiaOKVv+8+fOYtWJy5g1YxpPnjmdWTOmserEZcyfO6vX0SSp56Zs+QPk0P/m41OSpClc/kO7em5/JHlg4FG2P5Lu6ilpUlm/cSsX9N/G+o1bu/7cU/ZHXhu2PMj2EXv7bB8YZMOWBx36kTThffh713Heb/722PTJRy7h4yuf27Xnn7Jb/ndvfegJAz1ZzJekiWz9xq07FD/Aeb/+W1f/Apiy5X/thnsrzZekieL/1t9daf54TNnyP/qgBZXmS9JEsWDuzErzx2PKln/fAfN50YHzd5j3ogPn03fA/J0sIUkTw5HPWsCInykxLTrzu2XKfuEL8LXTj6D/r5v5xY13c/RBCyx+SZPC/Lmz+Ozrl/O+71xLRJCZfOZ1h3Z1Z5XInBz7v/f19WV/f3+vY0hSYzZv286GLQ+yeN7scRd/RFydmX0j50/pLX9Jmszmz51V267pU3bMX5K0c5a/JLWQ5S9JLWT5S1ILWf6S1EK1ln9EHBwRa4dd7ouI90TEPhFxaUTcWFzPqzOHJGlHtZZ/Zt6Qmcszcznwz8ADwIXAWcBlmXkQcFkxLUlqSJPDPscCN2XmrcBK4Nxi/rnAqxvMIUmt12T5vwH4ZnF7UWbeUdy+E1g02gIRcUZE9EdE/6ZNm5rIKEmt0Ej5R8RMYAXwnZH3Zef4EqMeYyIzV2dmX2b2LVy4sOaUktQeTW35vxL4fWZuLKY3RsS+AMX1XQ3lkCTRXPm/kceHfAAuBk4pbp8CXNRQDkkSDZR/RMwBXgZ8d9jsTwEvi4gbgX8ppiVJDan9qJ6ZeT8wf8S8zXT2/pEk9YC/8JWkFrL8JamFLH9JaiHLX5JayPKXpBay/CWphSx/SWohy1+SWsjyl6QWsvwlqYUsf0lqIctfklrI8pekFrL8JamFLH9JaiHLX5JayPKXpBay/CWphSx/SWohy1+SWsjyl6QWsvwlqYUsf0lqoRljPSAiXrOr+zPzu92LI0lqwpjlD5xQXD8VeAFweTH9EuBKwPKXpElmzPLPzFMBIuInwCGZeUcxvS/wv7WmkyTVosqY//5DxV/YCCzpch5JUgPKDPsMuSwifgx8s5h+PfDT7keSJNWtdPln5jsi4t+Ao4tZqzPzwnpiSZLqVGXLHzpf8D4CJPDb7seRJDWh9Jh/RJxEp/BfC5wEXBURr60rmCSpPlW2/M8Gnp+ZdwFExEI6Y/4X1BFMklSfKnv7TBsq/sLmistLkiaIKlv+l4yyt88Pux9JklS3Knv7vL841MNRxSz39pGkSarq3j6/AgZwbx9JmtTc20
eSWsi9fSSphdzbR5JaqPa9fSJib+Ac4J/ofFfwNuAG4HxgKXALcFJmbqmQRZK0G0pvuWfm+4HVwLLisjozzyyx6OeASzLz2cChwPXAWcBlmXkQcFkxLUlqSKW9fTJzDbCm7OMjYi86B4J7a7H8w8DDEbESOKZ42LnAFUCZDxJJUhdU2dvnNRFxY0TcGxH3RcTWiLhvjMUOADYBX42IayLinIiYAywadm6AO4FFO3nNMyKiPyL6N23aVDaqJGkMVb6wXQWsyMy9MvMpmblnZj5ljGVmAIcBX87M5wH3M2KIJzOTzncBT5CZqzOzLzP7Fi5cWCGqJGlXqpT/xsy8vuLzbwA2ZOZVxfQFdD4MNhangRw6HeRdO1leklSDMcf8i0M6APRHxPnA94DtQ/dn5k5P4J6Zd0bEbRFxcGbeABwL/Km4nAJ8qri+aNz/AklSZWW+8D1h2O0HgJcPm05gp+VfeCfw9YiYCdwMnErnL45vR8RpwK10fjEsSWrImOWfmafuzgtk5lqgb5S7jt2d55UkjV+ZYZ8PZOaqiPgCo3wxm5nvqiWZJKk2ZYZ9hr7k7a8ziCSpOWWGfb5fXJ9bfxxJUhPKDPt8n53shw+QmSu6mkiSVLsywz6fqT2FJKlRZYZ9fj50OyJmA0uKffYlSZNUlWP7nACsBS4pppdHxMU15ZIk1ajK4R0+ChwO3AOP7b9/QNcTSZJqV6X8BzLz3hHzdvpFsCRp4qpyPP8/RsSbgOkRcRDwLuDKemJJkupUZcv/ncBz6BzU7RvAfcC76wglSapXlfJ/Y2aenZnPLy5nAx+rK5gkqT5Vhn1OjIiHMvPrABHxRWB2PbEkSXWqVP7AxRExCBwH3JOZp9UTS5JUpzKHd9hn2OTpdE7m8ivgYxGxT2b+o6ZskqSalNnyv5rOLp0x7PpVxSWBZ9aWTpJUizKHd/CHXJI0xZQZ9nlpZl4+7Fy+O9jVOXwlSRNTmWGfFwOXs+O5fIeUOYevJGmCKTPs85HierfO5StJmjjKDPu8d1f3Z+Z/di+OJKkJZYZ99qw9hSSpUWWGfUodwiEiPpiZn9z9SJKkulU5ts9YXtfF55Ik1aib5R9dfC5JUo26Wf6e2EWSJgm3/CWphbpZ/t/p4nNJkmpU+pDOEfH5UWbfC/Rn5kWZ+YnuxZIk1anKlv+TgOXAjcVlGbAYOC0iPtv1ZJKk2lQ5mcsy4IWZ+ShARHwZ+CVwFHBdDdkkSTWpsuU/D5g7bHoOsE/xYbC9q6kkSbWqsuW/ClgbEVfQ2bPnaOATETEH+GkN2SRJNSld/pn5lYj4IXB4MetDmfn34vb7u55MklSbKlv+AM8HXlTcHgT+vovHSpImqNJj/hHxKeDdwJ+Ky7siwt07JWkSqrLlfzywPDMHASLiXOAa4EN1BJMk1afqL3z3HnZ7ry7mkCQ1qMqW/yeBayLiZzy+t89ZYy0UEbcAW4FHgUcysy8i9gHOB5YCtwAnZeaWSsklSeNWess/M78JHEHnhO1rgCMz8/ySi78kM5dnZl8xfRZwWWYeBFxGiQ8RSVL3lDmH72EjZm0orp8eEU/PzN+P43VXAscUt88FrgDOHMfzSJLGocywz3/s4r4EXjrG8gn8JCIS+O/MXA0sysw7ivvvBBaNtmBEnAGcAbBkyZISUSVJZZQ5h+9LyjxRRLwsMy8d5a6jMvP2iHgqcGlE/HnE82fxwTDaa68GVgP09fV5shhJ6pJuHs//06PNzMzbi+u7gAvp/EJ4Y0TsC1Bc39XFHJKkMdR6Jq+ImBMRew7dBl4O/AG4GDileNgpwEVdzCFJGkPVwzvsymjDMouACyNi6LW+kZmXRMTvgG9HxGnArcBJXcwhSRpDN8v/CTLzZuDQUeZvBo6t87UlSTvXzWGfW7r4XJKkGlXa8o+IF9D5Ve5jy2XmecX1a7qaTJJUmyoncP8a8CxgLZ1DNUBnnP+87seSJNWpypZ/H3BIZrq/vSRNclXG/P8APK2uIJ
Kk5lTZ8l8A/CkifsuwE7Zn5oqup5Ik1apK+X+0rhCSpGZVOYH7z+sMIklqTpVz+B4REb+LiG0R8XBEPBoR99UZTpJUjypf+H4ReCNwIzAbOB34Uh2hJEn1qvQL38xcD0zPzEcz86vAcfXEkiTVqcoXvg9ExExgbUSsAu6gu4eHkCQ1pEp5v6V4/DuA+4H9gRPrCCVJqleVvX1ujYjZwL6Z+bEaM0mSalZlb58T6BzX55JienlEXFxTLklSjaoM+3yUzikY7wHIzLXAAV1PJEmqXZXyH8jMe0fM8yBvkjQJVdnb548R8SZgekQcBLwLuLKeWJKkOlXZ8n8n8Bw6B3X7BnAv8O46QkmS6lWl/A8pLjOAJwErgd/VEUqSVK8qwz5fB95H57j+g/XEkSQ1oUr5b8rM79eWRJLUmCrl/5GIOAe4jB1P5vLdrqeSJNWqSvmfCjwb2IPHh30SsPwlaZKpUv7Pz8yDa0siSWpMlb19royIQ2pLIklqTJUt/yPoHM75r3TG/APIzFxWSzJJUm2qlL8nbpGkBm3etp0NWx5k8bzZzJ87q6vPXemQzl19ZUnSTl209nbOXLOOPaZNY2BwkFUnLmPF8v269vyeiUuSJpjN27Zz5pp1PDQwyNbtj/DQwCAfWLOOzdu2j71wSZa/JE0wG7Y8yB7TdqznPaZNY8OWB7v2Gpa/JE0wi+fNZmBwx6PoDAwOsnje7K69huUvSRPM/LmzWHXiMp60xzT2nDWDJ+0xjVUnLuvql75V9vaRJDVkxfL9eOGBC3q/t48kqVnz587qeukPcdhHklrI8pekFrL8JamFGin/iJgeEddExA+K6QMi4qqIWB8R50fEzCZySJI6mtryfzdw/bDpTwP/lZkHAluA0xrKIUmigfKPiMXAq4BziukAXgpcUDzkXODVdeeQJD2uiS3/zwIf4PGzf80H7snMR4rpDUD3jlYkSRpTreUfEf8K3JWZV49z+TMioj8i+jdt2tTldJLUXnVv+b8QWBERtwDfojPc8zlg74gY+oHZYuD20RbOzNWZ2ZeZfQsXLqw5qiS1R63ln5kfzMzFmbkUeANweWa+GfgZ8NriYacAF9WZQ5K0o17t538m8N6IWE/nO4Cv9CiHJLVSY8f2ycwrgCuK2zcDhzf12pKkHfkLX0lqIctfklrI8pekFrL8JamFLH9JaiHLX5JayPKXpBay/CWphSx/SWohy1+SWsjyl6QWsvwlqYUsf0lqIctfklrI8pekFrL8JamFLH9JaiHLX5JayPKXpBay/CWphSx/SWohy1+SWsjyl6QWsvwlqYUsf0maoNZv3MoF/bexfuPWrj/3jK4/oyRpt334e9dx3m/+9tj0yUcu4eMrn9u153fLX5ImmPUbt+5Q/ADn/fpvXf0LwPKXpAlm7W33VJo/Hpa/JE0wy/ffu9L88bD8JWmCOXDRnpx85JId5p185BIOXLRn117DL3wlaQL6+MrncvIRS1l72z0s33/vrhY/WP6SNGEduGjPrpf+EId9JKmFLH9JaiHLX5JayPKXpBay/CWphSIze52hlIjYBNw6zsUXAHd3MU63mKsac1Vjrmqmaq5nZObCkTMnTfnvjojoz8y+XucYyVzVmKsac1XTtlwO+0hSC1n+ktRCbSn/1b0OsBPmqsZc1ZirmlblasWYvyRpR23Z8pckDWP5S1ILTfryj4hbIuK6iFgbEf2j3B8R8fmIWB8R6yLisGH3nRIRNxaXUxrO9eYiz3URcWVEHFp22ZpzHRMR9xb3r42IDw+777iIuKF4L89qONf7h2X6Q0Q8GhH7lFl2N3PtHREXRMSfI+L6iDhyxP29Wr/GytWr9WusXL1av8bK1fj6FREHD3vNtRFxX0S8Z8Rj6lu/MnNSX4BbgAW7uP944EdAAEcAVxXz9wFuLq7nFbfnNZjrBUOvB7xyKFeZZWvOdQzwg1HmTwduAp4JzASuBQ5pKteIx54AXN7Q+3UucHpxeyaw9wRZv8
bK1av1a6xcvVq/dpmrV+vXiH//nXR+kNXI+jXpt/xLWAmclx2/AfaOiH2BVwCXZuY/MnMLcClwXFOhMvPK4nUBfgMsbuq1x+lwYH1m3pyZDwPfovPe9sIbgW/W/SIRsRdwNPAVgMx8ODPvGfGwxtevMrl6sX6VfL92prb1axy5Glm/RjgWuCkzRx7FoLb1ayqUfwI/iYirI+KMUe7fD7ht2PSGYt7O5jeVa7jT6Hy6j2fZOnIdGRHXRsSPIuI5xbwJ8X5FxJPprORrqi47DgcAm4CvRsQ1EXFORMwZ8ZherF9lcg3X1PpVNlfT61fp96vh9Wu4NzD6B05t69dUKP+jMvMwOn/a/ntEHN3rQIVSuSLiJXT+z3lm1WVryvV7On96Hgp8AfheF197d3INOQH4VWb+YxzLVjUDOAz4cmY+D7gf6OpY9DiVztXw+lUmVy/Wryr/HZtcvwCIiJnACuA73XzesUz68s/M24vru4AL6fz5ONztwP7DphcX83Y2v6lcRMQy4BxgZWZurrJsXbky877M3Fbc/iGwR0QsYAK8X4UnbCHV+H5tADZk5lXF9AV0SmS4XqxfZXL1Yv0aM1eP1q9S71ehyfVryCuB32fmxlHuq239mtTlHxFzImLPodvAy4E/jHjYxcDJxbfmRwD3ZuYdwI+Bl0fEvIiYVyz746ZyRcQS4LvAWzLzLxX/TXXmelpERHH7cDrryGbgd8BBEXFAsaXyBjrvbSO5ivv2Al4MXFR12fHIzDuB2yLi4GLWscCfRjys8fWrTK5erF8lczW+fpX879j4+jXMrr5jqG/92t1vqXt5obNnwLXF5Y/A2cX8twNvL24H8CU6exJcB/QNW/5twPricmrDuc4BtgBri0v/rpZtMNc7ivuupfNF4QuGLX888JfivWw0VzH9VuBbZZbtYrblQD+wjs4Qxbxer18lczW+fpXM1fj6VSZXD9evOXQ+/PYaNq+R9cvDO0hSC03qYR9J0vhY/pLUQpa/JLWQ5S9JLWT5S1ILWf7SOEXEtmG3j4+Iv0TEM3qZSSprRq8DSJNdRBwLfB54RT7xwFzShGT5S7uhOM7L/wDHZ+ZNvc4jleWPvKRxiogBYCtwTGau63UeqQrH/KXxGwCupHPUTGlSsfyl8RsETgIOj4gP9TqMVIVj/tJuyMwHIuJVwC8jYmNmfqXXmaQyLH9pN2XmPyLiOOAXEbEpM7tyKGKpTn7hK0kt5Ji/JLWQ5S9JLWT5S1ILWf6S1EKWvyS1kOUvSS1k+UtSC/0/JImDh+CkV0IAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"results.plot(x='K', y='mean_log_likelihood', kind='scatter')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}