Created
March 12, 2021 00:01
sneakers-the-rat/32cfc3b8939d1e18aebc0ab6e03b72b0
grid_search_ssm.ipynb
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Parameter Search\n", | |
"\n", | |
"We can automatically find the optimal parameters for a model by running a shitload of models and comparing them. To do that all we need is a method for fitting the model and a method for scoring it, or evaluating how 'good' the model is. \n", | |
"\n", | |
"There's a general python package that's commonly used for HMMs in the sklearn API, https://github.com/hmmlearn/hmmlearn , but for compatibility we'll consider the ssm package https://github.com/lindermanlab/ssm\n", | |
"\n", | |
"We should make this fast by making sure we have OpenMP so we can do it in parallel:\n", | |
"\n", | |
"On Mac (also install llvm to get `-fopenmp`):\n", | |
"`brew install libomp llvm`\n", | |
"\n", | |
"On Ubuntu:\n", | |
"`sudo apt update && sudo apt install -y libomp-dev`\n", | |
"\n", | |
"and then install the ssm module with OpenMP enabled\n", | |
"`USE_OPENMP=True pip install -e .`" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Loading Data\n", | |
"\n", | |
"Copied and pasted from https://github.com/wehr-lab/Prey-Capture-SSM/blob/master/ssm_preycap_posterior.py" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# loads training data from mat-file, saves posterior probabilities into ssm_posterior_probs.mat\n", | |
"\n", | |
"# builtins\n", | |
"import itertools\n", | |
"import time\n", | |
"import os\n", | |
"import sys\n", | |
"import pprint\n", | |
"import multiprocessing as mp\n", | |
"\n", | |
"# installed\n", | |
"import tqdm\n", | |
"import numpy as np\n", | |
"import ssm\n", | |
"from scipy import io, stats, signal\n", | |
"import pandas as pd\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#load data using loadmat\n", | |
"mat=io.loadmat('training_data.mat') \n", | |
"X = mat['X']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Data was (195663, 14) and now it's (32611, 14)\n" | |
] | |
} | |
], | |
"source": [ | |
"big_shape = X.shape\n", | |
"# np.save('training_data.npy', X)\n", | |
"\n", | |
"# downsampling by factor of 6 for ~33Hz\n", | |
"# actually don't since it's already decimated, but just for the sake of\n", | |
"# showing how to do in python\n", | |
"#X = signal.decimate(X, 6, axis=0)\n", | |
"\n", | |
"print(f'Data was {big_shape} and now it\\'s {X.shape}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Specifying Parameters\n", | |
"\n", | |
"In this case, we will do an exhaustive grid search, since it would take a bit of work to get the ssm objects to work with the sklearn API. (See their [documentation](https://scikit-learn.org/stable/modules/grid_search.html#randomized-parameter-optimization) for reasons why that would be a good idea.)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"num_states = list(range(5,20,2)) # K - number of discrete states\n", | |
"observations = ['autoregressive']\n", | |
"transitions = ['standard', 'sticky', 'recurrent']\n", | |
"\n", | |
"# since the hmm library has a nested api, \n", | |
"# have to create the arguments as lists of dictionaries\n", | |
"observation_kwargs = [{'lags':i} for i in range(1,11,5)]\n", | |
"\n", | |
"# static parameters\n", | |
"obs_dim = [X.shape[1]] # dimensionality of observation\n", | |
"\n", | |
"# --------------\n", | |
"# combine in a dictionary!\n", | |
"param_search = {\n", | |
"    'K': num_states,\n", | |
"    'D': obs_dim,\n", | |
"    'observations': observations,\n", | |
"    'transitions': transitions,\n", | |
"    'observation_kwargs': observation_kwargs\n", | |
"}\n", | |
"\n", | |
"# make a lil generator for iterating over params\n", | |
"def param_product(**kwargs):\n", | |
"    \"\"\"https://stackoverflow.com/a/5228294/13113166\"\"\"\n", | |
"    keys = kwargs.keys()\n", | |
"    vals = kwargs.values()\n", | |
"    for instance in itertools.product(*vals):\n", | |
"        yield dict(zip(keys, instance))\n" | |
] | |
}, | |
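As a sanity check on the grid above, here is a minimal standalone sketch of the same cartesian-product idea. It hard-codes `D` as 14 (the number of columns reported for `X`) so it runs without the data, and the names `grid`, `n_models` and `combos` are made up for this example. The number of models is just the product of the list lengths, and `itertools.product` varies the last key fastest.

```python
import itertools
import math

# same values as param_search, with D hard-coded so no data file is needed
grid = {
    'K': list(range(5, 20, 2)),                 # 8 values: 5, 7, ..., 19
    'D': [14],
    'observations': ['autoregressive'],
    'transitions': ['standard', 'sticky', 'recurrent'],
    'observation_kwargs': [{'lags': i} for i in range(1, 11, 5)],  # lags 1 and 6
}

# grid size is the product of the number of options per key: 8 * 1 * 1 * 3 * 2
n_models = math.prod(len(v) for v in grid.values())

# one dict per combination; the last key ('observation_kwargs') changes fastest
combos = [dict(zip(grid, values)) for values in itertools.product(*grid.values())]

print(n_models, len(combos))
print(combos[1])
```

Adding one more option to any list multiplies the total, so it is worth computing `n_models` before launching anything expensive.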
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Will fit 48 models!\n", | |
"\n", | |
"[{'D': 14,\n", | |
"  'K': 5,\n", | |
"  'observation_kwargs': {'lags': 1},\n", | |
"  'observations': 'autoregressive',\n", | |
"  'transitions': 'standard'},\n", | |
" {'D': 14,\n", | |
"  'K': 5,\n", | |
"  'observation_kwargs': {'lags': 6},\n", | |
"  'observations': 'autoregressive',\n", | |
"  'transitions': 'standard'},\n", | |
" {'D': 14,\n", | |
"  'K': 5,\n", | |
"  'observation_kwargs': {'lags': 1},\n", | |
"  'observations': 'autoregressive',\n", | |
"  'transitions': 'sticky'}]\n" | |
] | |
} | |
], | |
"source": [ | |
"permutations = list(param_product(**param_search))\n", | |
"\n", | |
"print(f'Will fit {len(permutations)} models!\\n')\n", | |
"\n", | |
"# print the first 3 as an example\n", | |
"pprint.pprint(permutations[0:3])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Parallel Training & X-Val\n", | |
"\n", | |
"We'll use python's multiprocessing module to run models in parallel -- it proved to be too tricky to modify the HMM model to use the sklearn API. We'll also use the builtin model selection code in the ssm module to run a pool of workers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"def train_and_xval(input:tuple, n_repeats=3, hold_fraction=0.2):\n", | |
"    \"\"\"\n", | |
"    input is a tuple of data, params (bc pool only gives one argument)\n", | |
"    \"\"\"\n", | |
"    data, params = input\n", | |
"    model = ssm.HMM(**params)\n", | |
"    \n", | |
"    n_holdout = np.floor(hold_fraction*data.shape[0]).astype(int)\n", | |
"    \n", | |
"    # print a blank line to get pbars to show up in notebook\n", | |
"    print('', end='\\r')\n", | |
"    \n", | |
"    log_likelihoods = []\n", | |
"    for run in range(n_repeats):\n", | |
"        # split out a continuous region of data for x-val\n", | |
"        # pick a random number from beginning to end-n_holdouts\n", | |
"        censor_start = np.random.randint(0,data.shape[0]-n_holdout-1)\n", | |
"        data_test = data[censor_start:censor_start+n_holdout,:]\n", | |
"        data_train = np.delete(data, slice(censor_start,censor_start+n_holdout), axis=0)\n", | |
"        \n", | |
"        # fit model\n", | |
"        model.fit(data_train)\n", | |
"        \n", | |
"        # return normalized log likelihood for test set\n", | |
"        log_likelihoods.append(model.log_likelihood(data_test)/data_test.shape[0])\n", | |
"        \n", | |
"    \n", | |
"    return log_likelihoods, params" | |
] | |
}, | |
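The hold-out logic in `train_and_xval` can be looked at on its own. Below is a minimal sketch of picking one random contiguous block as the test set, written with the standard library and plain lists so it runs without numpy or ssm. The function name `contiguous_holdout` and the explicit `rng` argument are assumptions for this example, not part of the notebook. Passing a seeded generator per task is a common precaution: when worker processes are forked they may start from identical copies of a global random state and draw the same splits.

```python
import random

def contiguous_holdout(n_samples, hold_fraction=0.2, rng=None):
    """Return (start, stop) of a random contiguous block to hold out."""
    rng = rng or random.Random()
    n_holdout = int(n_samples * hold_fraction)
    # random.Random.randint includes both endpoints, so the block always fits
    start = rng.randint(0, n_samples - n_holdout)
    return start, start + n_holdout

# toy "time series" of 100 samples, seeded so the split is reproducible
rows = list(range(100))
start, stop = contiguous_holdout(len(rows), 0.2, random.Random(0))

test_rows = rows[start:stop]
# training data is everything before and after the held-out block;
# note the two pieces get glued together, as np.delete does in the cell above
train_rows = rows[:start] + rows[stop:]

print(start, stop, len(train_rows), len(test_rows))
```

Gluing the two training pieces together creates one artificial jump in the time series at the join. For long recordings this is probably minor, but it is something to keep in mind with autoregressive observations.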
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# time 2 train!!!\n", | |
"\n", | |
"use multiprocessing and train a bunch of models!!!\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def grid_search(data:np.ndarray, param_search:list, n_processes:int=12):\n", | |
"    \"\"\"\n", | |
"    Given data and a list of parameter sets to search, \n", | |
"    do a grid search with the SSM package.\n", | |
"    \n", | |
"    param_search is a list of dictionaries, where each dictionary is one set of\n", | |
"    model run parameters\n", | |
"    \n", | |
"    (e.g. one element of the cartesian product generated by param_product)\n", | |
"    \"\"\"\n", | |
"    \n", | |
"    # create iterator that repeats the data for each item in the iterator\n", | |
"    combined_iter = zip(itertools.cycle([data]), param_search)\n", | |
"    \n", | |
"    # store results \n", | |
"    model_runs = []\n", | |
"    \n", | |
"    # create multiprocessing pool\n", | |
"    with mp.Pool(n_processes) as pool:\n", | |
"        \n", | |
"        # start workers asynchronously\n", | |
"        results = pool.imap_unordered(train_and_xval, combined_iter)\n", | |
"        \n", | |
"        # create progress bar\n", | |
"        pbar = tqdm.tqdm(total=len(param_search))\n", | |
"        \n", | |
"        for r in results:\n", | |
"            # unpack results a bit for easier dataframe creation later\n", | |
"            log_likelihoods, params = r\n", | |
"            params['log_likelihoods'] = log_likelihoods\n", | |
"            model_runs.append(params)\n", | |
"            pbar.update()\n", | |
"        pbar.close()\n", | |
"    \n", | |
"    # combine into pandas dataframe\n", | |
"    return pd.DataFrame(model_runs)\n", | |
" \n", | |
" " | |
] | |
}, | |
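The pattern in `grid_search` (pair the same data with every parameter set, hand the pairs to a pool, collect results in whatever order they finish) can be tried out without ssm. The sketch below swaps in `multiprocessing.pool.ThreadPool`, which offers the same `imap_unordered` method as `mp.Pool`, purely so the example runs anywhere. The `score` function and its arithmetic are stand-ins invented for this example, not the real model fit.

```python
import itertools
from multiprocessing.pool import ThreadPool

def score(task):
    """Stand-in for train_and_xval: takes one (data, params) tuple."""
    data, params = task
    # made-up score so the sketch runs without fitting anything
    return sum(data) * params['K'], params

def run_grid(data, param_list, n_workers=4):
    # pair the same data object with every parameter dictionary
    tasks = zip(itertools.repeat(data), param_list)
    rows = []
    with ThreadPool(n_workers) as pool:
        # results arrive in completion order, not submission order
        for value, params in pool.imap_unordered(score, tasks):
            # build a new dict so the caller's dictionaries are left untouched
            rows.append({**params, 'score': value})
    return rows

rows = run_grid([1, 2, 3], [{'K': k} for k in (5, 7, 9)])
# sort afterwards if a stable order is needed
rows_by_k = sorted(rows, key=lambda r: r['K'])
print(rows_by_k)
```

Because results come back unordered, each result carries its own parameters with it, which is the same reason `train_and_xval` returns `params` alongside the log likelihoods.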
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
"100%|██████████| 12/12 [11:26<00:00, 54.08s/it]" | |
] | |
} | |
], | |
"source": [ | |
"params = param_product(**param_search)\n", | |
"params = list(params)[0:12]\n", | |
"\n", | |
"results = grid_search(X[0:10000,:], params, n_processes=12)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>K</th>\n", | |
" <th>D</th>\n", | |
" <th>observations</th>\n", | |
" <th>transitions</th>\n", | |
" <th>observation_kwargs</th>\n", | |
" <th>log_likelihoods</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>sticky</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.63600490820458, 73.66281742297728, 71.8680...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>standard</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.64167739121682, 74.97694287706454, 72.0947...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>sticky</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.67920859553763, 70.55818373405485, 74.6496...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>standard</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.71081642856838, 70.63887450049349, 74.4475...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>recurrent</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.63214419000296, 73.6926339542364, 72.05606...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>recurrent</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.76921456088444, 70.81079060428786, 74.6286...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>standard</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[76.4593292441041, 77.44156131923047, 69.78935...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>sticky</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[77.16071708938951, 74.50652839566231, 67.1754...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>recurrent</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[77.14710032562544, 76.64457651074046, 70.9074...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>sticky</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[55.10422483160579, -4.791925793428352, 49.359...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>10</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>standard</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[57.016592852844255, -6.453637578048345, 48.95...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>11</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>recurrent</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[57.57047274312927, 22.043312475814187, 47.546...</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" K D observations transitions observation_kwargs \\\n", | |
"0 5 14 autoregressive sticky {'lags': 1} \n", | |
"1 5 14 autoregressive standard {'lags': 1} \n", | |
"2 7 14 autoregressive sticky {'lags': 1} \n", | |
"3 7 14 autoregressive standard {'lags': 1} \n", | |
"4 5 14 autoregressive recurrent {'lags': 1} \n", | |
"5 7 14 autoregressive recurrent {'lags': 1} \n", | |
"6 5 14 autoregressive standard {'lags': 6} \n", | |
"7 5 14 autoregressive sticky {'lags': 6} \n", | |
"8 5 14 autoregressive recurrent {'lags': 6} \n", | |
"9 7 14 autoregressive sticky {'lags': 6} \n", | |
"10 7 14 autoregressive standard {'lags': 6} \n", | |
"11 7 14 autoregressive recurrent {'lags': 6} \n", | |
"\n", | |
" log_likelihoods \n", | |
"0 [73.63600490820458, 73.66281742297728, 71.8680... \n", | |
"1 [73.64167739121682, 74.97694287706454, 72.0947... \n", | |
"2 [73.67920859553763, 70.55818373405485, 74.6496... \n", | |
"3 [73.71081642856838, 70.63887450049349, 74.4475... \n", | |
"4 [73.63214419000296, 73.6926339542364, 72.05606... \n", | |
"5 [73.76921456088444, 70.81079060428786, 74.6286... \n", | |
"6 [76.4593292441041, 77.44156131923047, 69.78935... \n", | |
"7 [77.16071708938951, 74.50652839566231, 67.1754... \n", | |
"8 [77.14710032562544, 76.64457651074046, 70.9074... \n", | |
"9 [55.10422483160579, -4.791925793428352, 49.359... \n", | |
"10 [57.016592852844255, -6.453637578048345, 48.95... \n", | |
"11 [57.57047274312927, 22.043312475814187, 47.546... " | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"results" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"results['mean_log_likelihood'] = [np.mean(vals) for vals in results.log_likelihoods.values]\n", | |
"results['sd_log_likelihood'] = [np.std(vals) for vals in results.log_likelihoods.values]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>K</th>\n", | |
" <th>D</th>\n", | |
" <th>observations</th>\n", | |
" <th>transitions</th>\n", | |
" <th>observation_kwargs</th>\n", | |
" <th>log_likelihoods</th>\n", | |
" <th>mean_log_likelihood</th>\n", | |
" <th>sd_log_likelihood</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>10</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>standard</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[57.016592852844255, -6.453637578048345, 48.95...</td>\n", | |
" <td>33.172303</td>\n", | |
" <td>28.212443</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>sticky</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[55.10422483160579, -4.791925793428352, 49.359...</td>\n", | |
" <td>33.223894</td>\n", | |
" <td>26.983362</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>11</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>recurrent</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[57.57047274312927, 22.043312475814187, 47.546...</td>\n", | |
" <td>42.386693</td>\n", | |
" <td>14.955732</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>standard</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.71081642856838, 70.63887450049349, 74.4475...</td>\n", | |
" <td>72.932398</td>\n", | |
" <td>1.649417</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>sticky</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[77.16071708938951, 74.50652839566231, 67.1754...</td>\n", | |
" <td>72.947558</td>\n", | |
" <td>4.222897</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>sticky</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.67920859553763, 70.55818373405485, 74.6496...</td>\n", | |
" <td>72.962356</td>\n", | |
" <td>1.745563</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>sticky</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.63600490820458, 73.66281742297728, 71.8680...</td>\n", | |
" <td>73.055620</td>\n", | |
" <td>0.839819</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>7</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>recurrent</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.76921456088444, 70.81079060428786, 74.6286...</td>\n", | |
" <td>73.069553</td>\n", | |
" <td>1.635270</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>recurrent</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.63214419000296, 73.6926339542364, 72.05606...</td>\n", | |
" <td>73.126948</td>\n", | |
" <td>0.757630</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>standard</td>\n", | |
" <td>{'lags': 1}</td>\n", | |
" <td>[73.64167739121682, 74.97694287706454, 72.0947...</td>\n", | |
" <td>73.571137</td>\n", | |
" <td>1.177691</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>standard</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[76.4593292441041, 77.44156131923047, 69.78935...</td>\n", | |
" <td>74.563416</td>\n", | |
" <td>3.399502</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>5</td>\n", | |
" <td>14</td>\n", | |
" <td>autoregressive</td>\n", | |
" <td>recurrent</td>\n", | |
" <td>{'lags': 6}</td>\n", | |
" <td>[77.14710032562544, 76.64457651074046, 70.9074...</td>\n", | |
" <td>74.899693</td>\n", | |
" <td>2.830421</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" K D observations transitions observation_kwargs \\\n", | |
"10 7 14 autoregressive standard {'lags': 6} \n", | |
"9 7 14 autoregressive sticky {'lags': 6} \n", | |
"11 7 14 autoregressive recurrent {'lags': 6} \n", | |
"3 7 14 autoregressive standard {'lags': 1} \n", | |
"7 5 14 autoregressive sticky {'lags': 6} \n", | |
"2 7 14 autoregressive sticky {'lags': 1} \n", | |
"0 5 14 autoregressive sticky {'lags': 1} \n", | |
"5 7 14 autoregressive recurrent {'lags': 1} \n", | |
"4 5 14 autoregressive recurrent {'lags': 1} \n", | |
"1 5 14 autoregressive standard {'lags': 1} \n", | |
"6 5 14 autoregressive standard {'lags': 6} \n", | |
"8 5 14 autoregressive recurrent {'lags': 6} \n", | |
"\n", | |
" log_likelihoods mean_log_likelihood \\\n", | |
"10 [57.016592852844255, -6.453637578048345, 48.95... 33.172303 \n", | |
"9 [55.10422483160579, -4.791925793428352, 49.359... 33.223894 \n", | |
"11 [57.57047274312927, 22.043312475814187, 47.546... 42.386693 \n", | |
"3 [73.71081642856838, 70.63887450049349, 74.4475... 72.932398 \n", | |
"7 [77.16071708938951, 74.50652839566231, 67.1754... 72.947558 \n", | |
"2 [73.67920859553763, 70.55818373405485, 74.6496... 72.962356 \n", | |
"0 [73.63600490820458, 73.66281742297728, 71.8680... 73.055620 \n", | |
"5 [73.76921456088444, 70.81079060428786, 74.6286... 73.069553 \n", | |
"4 [73.63214419000296, 73.6926339542364, 72.05606... 73.126948 \n", | |
"1 [73.64167739121682, 74.97694287706454, 72.0947... 73.571137 \n", | |
"6 [76.4593292441041, 77.44156131923047, 69.78935... 74.563416 \n", | |
"8 [77.14710032562544, 76.64457651074046, 70.9074... 74.899693 \n", | |
"\n", | |
" sd_log_likelihood \n", | |
"10 28.212443 \n", | |
"9 26.983362 \n", | |
"11 14.955732 \n", | |
"3 1.649417 \n", | |
"7 4.222897 \n", | |
"2 1.745563 \n", | |
"0 0.839819 \n", | |
"5 1.635270 \n", | |
"4 0.757630 \n", | |
"1 1.177691 \n", | |
"6 3.399502 \n", | |
"8 2.830421 " | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"results.sort_values(by=\"mean_log_likelihood\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<AxesSubplot:xlabel='K', ylabel='mean_log_likelihood'>" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEGCAYAAACNaZVuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUyUlEQVR4nO3dfZAcdZ3H8fc3CYkxQQhJjEiIQaGw8AyRWylQRJRTEY/EE8WnEkQoyqrzqSwVlCqf/vAh5d35WNblsDywfECJCFqKIoh6ougiIaiIBAQJQggxQMJDWNjv/TG9sFk22e7NdM/u9vtVNc50z/TMJ0P7md7f9HRHZiJJapdpvQ4gSWqe5S9JLWT5S1ILWf6S1EKWvyS10IxeByhrwYIFuXTp0l7HkKRJ5eqrr747MxeOnD9pyn/p0qX09/f3OoYkTSoRceto8x32kaQWsvwlqYUsf0lqIctfklrI8pekFpry5b9523auve0eNm/b3usokjRhTJpdPcfjorW3c+aadewxbRoDg4OsOnEZK5bv1+tYktRzU3bLf/O27Zy5Zh0PDQyydfsjPDQwyAfWrPMvAEliCpf/hi0PVpovSW0yZct/zszpPDQwuMO8hwYGmTNzeo8SSdLEMWXL//6HH2XW9Nhh3qzpwf0PP9qjRJI0cUzZ8l88bzYxbcfyj2nB4nmze5RIkiaOKVv+8+fOYtWJy5g1YxpPnjmdWTOmserEZcyfO6vX0SSp56Zs+QPk0P/m41OSpClc/kO7em5/JHlg4FG2P5Lu6ilpUlm/cSsX9N/G+o1bu/7cU/ZHXhu2PMj2EXv7bB8YZMOWBx36kTThffh713Heb/722PTJRy7h4yuf27Xnn7Jb/ndvfegJAz1ZzJekiWz9xq07FD/Aeb/+W1f/Apiy5X/thnsrzZekieL/1t9daf54TNnyP/qgBZXmS9JEsWDuzErzx2PKln/fAfN50YHzd5j3ogPn03fA/J0sIUkTw5HPWsCInykxLTrzu2XKfuEL8LXTj6D/r5v5xY13c/RBCyx+SZPC/Lmz+Ozrl/O+71xLRJCZfOZ1h3Z1Z5XInBz7v/f19WV/f3+vY0hSYzZv286GLQ+yeN7scRd/RFydmX0j50/pLX9Jmszmz51V267pU3bMX5K0c5a/JLWQ5S9JLWT5S1ILWf6S1EK1ln9EHBwRa4dd7ouI90TEPhFxaUTcWFzPqzOHJGlHtZZ/Zt6Qmcszcznwz8ADwIXAWcBlmXkQcFkxLUlqSJPDPscCN2XmrcBK4Nxi/rnAqxvMIUmt12T5vwH4ZnF7UWbeUdy+E1g02gIRcUZE9EdE/6ZNm5rIKEmt0Ej5R8RMYAXwnZH3Zef4EqMeYyIzV2dmX2b2LVy4sOaUktQeTW35vxL4fWZuLKY3RsS+AMX1XQ3lkCTRXPm/kceHfAAuBk4pbp8CXNRQDkkSDZR/RMwBXgZ8d9jsTwEvi4gbgX8ppiVJDan9qJ6ZeT8wf8S8zXT2/pEk9YC/8JWkFrL8JamFLH9JaiHLX5JayPKXpBay/CWphSx/SWohy1+SWsjyl6QWsvwlqYUsf0lqIctfklrI8pekFrL8JamFLH9JaiHLX5JayPKXpBay/CWphSx/SWohy1+SWsjyl6QWsvwlqYUsf0lqoRljPSAiXrOr+zPzu92LI0lqwpjlD5xQXD8VeAFweTH9EuBKwPKXpElmzPLPzFMBIuInwCGZeUcxvS/wv7WmkyTVosqY//5DxV/YCCzpch5JUgPKDPsMuSwifgx8s5h+PfDT7keSJNWtdPln5jsi4t+Ao4tZqzPzwnpiSZLqVGXLHzpf8D4CJPDb7seRJDWh9Jh/RJxEp/BfC5wEXBURr60rmCSpPlW2/M8Gnp+ZdwFExEI6Y/4X1BFMklSfKnv7TBsq/sLmistLkiaIKlv+l4yyt88Pux9JklS3Knv7vL841MNRxSz39pGkSarq3j6/AgZwbx9JmtTc20
eSWsi9fSSphdzbR5JaqPa9fSJib+Ac4J/ofFfwNuAG4HxgKXALcFJmbqmQRZK0G0pvuWfm+4HVwLLisjozzyyx6OeASzLz2cChwPXAWcBlmXkQcFkxLUlqSKW9fTJzDbCm7OMjYi86B4J7a7H8w8DDEbESOKZ42LnAFUCZDxJJUhdU2dvnNRFxY0TcGxH3RcTWiLhvjMUOADYBX42IayLinIiYAywadm6AO4FFO3nNMyKiPyL6N23aVDaqJGkMVb6wXQWsyMy9MvMpmblnZj5ljGVmAIcBX87M5wH3M2KIJzOTzncBT5CZqzOzLzP7Fi5cWCGqJGlXqpT/xsy8vuLzbwA2ZOZVxfQFdD4MNhangRw6HeRdO1leklSDMcf8i0M6APRHxPnA94DtQ/dn5k5P4J6Zd0bEbRFxcGbeABwL/Km4nAJ8qri+aNz/AklSZWW+8D1h2O0HgJcPm05gp+VfeCfw9YiYCdwMnErnL45vR8RpwK10fjEsSWrImOWfmafuzgtk5lqgb5S7jt2d55UkjV+ZYZ8PZOaqiPgCo3wxm5nvqiWZJKk2ZYZ9hr7k7a8ziCSpOWWGfb5fXJ9bfxxJUhPKDPt8n53shw+QmSu6mkiSVLsywz6fqT2FJKlRZYZ9fj50OyJmA0uKffYlSZNUlWP7nACsBS4pppdHxMU15ZIk1ajK4R0+ChwO3AOP7b9/QNcTSZJqV6X8BzLz3hHzdvpFsCRp4qpyPP8/RsSbgOkRcRDwLuDKemJJkupUZcv/ncBz6BzU7RvAfcC76wglSapXlfJ/Y2aenZnPLy5nAx+rK5gkqT5Vhn1OjIiHMvPrABHxRWB2PbEkSXWqVP7AxRExCBwH3JOZp9UTS5JUpzKHd9hn2OTpdE7m8ivgYxGxT2b+o6ZskqSalNnyv5rOLp0x7PpVxSWBZ9aWTpJUizKHd/CHXJI0xZQZ9nlpZl4+7Fy+O9jVOXwlSRNTmWGfFwOXs+O5fIeUOYevJGmCKTPs85HierfO5StJmjjKDPu8d1f3Z+Z/di+OJKkJZYZ99qw9hSSpUWWGfUodwiEiPpiZn9z9SJKkulU5ts9YXtfF55Ik1aib5R9dfC5JUo26Wf6e2EWSJgm3/CWphbpZ/t/p4nNJkmpU+pDOEfH5UWbfC/Rn5kWZ+YnuxZIk1anKlv+TgOXAjcVlGbAYOC0iPtv1ZJKk2lQ5mcsy4IWZ+ShARHwZ+CVwFHBdDdkkSTWpsuU/D5g7bHoOsE/xYbC9q6kkSbWqsuW/ClgbEVfQ2bPnaOATETEH+GkN2SRJNSld/pn5lYj4IXB4MetDmfn34vb7u55MklSbKlv+AM8HXlTcHgT+vovHSpImqNJj/hHxKeDdwJ+Ky7siwt07JWkSqrLlfzywPDMHASLiXOAa4EN1BJMk1afqL3z3HnZ7ry7mkCQ1qMqW/yeBayLiZzy+t89ZYy0UEbcAW4FHgUcysy8i9gHOB5YCtwAnZeaWSsklSeNWess/M78JHEHnhO1rgCMz8/ySi78kM5dnZl8xfRZwWWYeBFxGiQ8RSVL3lDmH72EjZm0orp8eEU/PzN+P43VXAscUt88FrgDOHMfzSJLGocywz3/s4r4EXjrG8gn8JCIS+O/MXA0sysw7ivvvBBaNtmBEnAGcAbBkyZISUSVJZZQ5h+9LyjxRRLwsMy8d5a6jMvP2iHgqcGlE/HnE82fxwTDaa68GVgP09fV5shhJ6pJuHs//06PNzMzbi+u7gAvp/EJ4Y0TsC1Bc39XFHJKkMdR6Jq+ImBMRew7dBl4O/AG4GDileNgpwEVdzCFJGkPVwzvsymjDMouACyNi6LW+kZmXRMTvgG9HxGnArcBJXcwhSRpDN8v/CTLzZuDQUeZvBo6t87UlSTvXzWGfW7r4XJKkGlXa8o+IF9D5Ve5jy2XmecX1a7qaTJJUmyoncP8a8CxgLZ1DNUBnnP+87seSJNWpypZ/H3BIZrq/vSRNclXG/P8APK2uIJ
Kk5lTZ8l8A/CkifsuwE7Zn5oqup5Ik1apK+X+0rhCSpGZVOYH7z+sMIklqTpVz+B4REb+LiG0R8XBEPBoR99UZTpJUjypf+H4ReCNwIzAbOB34Uh2hJEn1qvQL38xcD0zPzEcz86vAcfXEkiTVqcoXvg9ExExgbUSsAu6gu4eHkCQ1pEp5v6V4/DuA+4H9gRPrCCVJqleVvX1ujYjZwL6Z+bEaM0mSalZlb58T6BzX55JienlEXFxTLklSjaoM+3yUzikY7wHIzLXAAV1PJEmqXZXyH8jMe0fM8yBvkjQJVdnb548R8SZgekQcBLwLuLKeWJKkOlXZ8n8n8Bw6B3X7BnAv8O46QkmS6lWl/A8pLjOAJwErgd/VEUqSVK8qwz5fB95H57j+g/XEkSQ1oUr5b8rM79eWRJLUmCrl/5GIOAe4jB1P5vLdrqeSJNWqSvmfCjwb2IPHh30SsPwlaZKpUv7Pz8yDa0siSWpMlb19royIQ2pLIklqTJUt/yPoHM75r3TG/APIzFxWSzJJUm2qlL8nbpGkBm3etp0NWx5k8bzZzJ87q6vPXemQzl19ZUnSTl209nbOXLOOPaZNY2BwkFUnLmPF8v269vyeiUuSJpjN27Zz5pp1PDQwyNbtj/DQwCAfWLOOzdu2j71wSZa/JE0wG7Y8yB7TdqznPaZNY8OWB7v2Gpa/JE0wi+fNZmBwx6PoDAwOsnje7K69huUvSRPM/LmzWHXiMp60xzT2nDWDJ+0xjVUnLuvql75V9vaRJDVkxfL9eOGBC3q/t48kqVnz587qeukPcdhHklrI8pekFrL8JamFGin/iJgeEddExA+K6QMi4qqIWB8R50fEzCZySJI6mtryfzdw/bDpTwP/lZkHAluA0xrKIUmigfKPiMXAq4BziukAXgpcUDzkXODVdeeQJD2uiS3/zwIf4PGzf80H7snMR4rpDUD3jlYkSRpTreUfEf8K3JWZV49z+TMioj8i+jdt2tTldJLUXnVv+b8QWBERtwDfojPc8zlg74gY+oHZYuD20RbOzNWZ2ZeZfQsXLqw5qiS1R63ln5kfzMzFmbkUeANweWa+GfgZ8NriYacAF9WZQ5K0o17t538m8N6IWE/nO4Cv9CiHJLVSY8f2ycwrgCuK2zcDhzf12pKkHfkLX0lqIctfklrI8pekFrL8JamFLH9JaiHLX5JayPKXpBay/CWphSx/SWohy1+SWsjyl6QWsvwlqYUsf0lqIctfklrI8pekFrL8JamFLH9JaiHLX5JayPKXpBay/CWphSx/SWohy1+SWsjyl6QWsvwlqYUsf0maoNZv3MoF/bexfuPWrj/3jK4/oyRpt334e9dx3m/+9tj0yUcu4eMrn9u153fLX5ImmPUbt+5Q/ADn/fpvXf0LwPKXpAlm7W33VJo/Hpa/JE0wy/ffu9L88bD8JWmCOXDRnpx85JId5p185BIOXLRn117DL3wlaQL6+MrncvIRS1l72z0s33/vrhY/WP6SNGEduGjPrpf+EId9JKmFLH9JaiHLX5JayPKXpBay/CWphSIze52hlIjYBNw6zsUXAHd3MU63mKsac1Vjrmqmaq5nZObCkTMnTfnvjojoz8y+XucYyVzVmKsac1XTtlwO+0hSC1n+ktRCbSn/1b0OsBPmqsZc1ZirmlblasWYvyRpR23Z8pckDWP5S1ILTfryj4hbIuK6iFgbEf2j3B8R8fmIWB8R6yLisGH3nRIRNxaXUxrO9eYiz3URcWVEHFp22ZpzHRMR9xb3r42IDw+777iIuKF4L89qONf7h2X6Q0Q8GhH7lFl2N3PtHREXRMSfI+L6iDhyxP29Wr/GytWr9WusXL1av8bK1fj6FREHD3vNtRFxX0S8Z8Rj6lu/MnNSX4BbgAW7uP944EdAAEcAVxXz9wFuLq7nFbfnNZjrBUOvB7xyKFeZZWvOdQzwg1HmTwduAp4JzASuBQ5pKteIx54AXN7Q+3UucHpxeyaw9wRZv8
bK1av1a6xcvVq/dpmrV+vXiH//nXR+kNXI+jXpt/xLWAmclx2/AfaOiH2BVwCXZuY/MnMLcClwXFOhMvPK4nUBfgMsbuq1x+lwYH1m3pyZDwPfovPe9sIbgW/W/SIRsRdwNPAVgMx8ODPvGfGwxtevMrl6sX6VfL92prb1axy5Glm/RjgWuCkzRx7FoLb1ayqUfwI/iYirI+KMUe7fD7ht2PSGYt7O5jeVa7jT6Hy6j2fZOnIdGRHXRsSPIuI5xbwJ8X5FxJPprORrqi47DgcAm4CvRsQ1EXFORMwZ8ZherF9lcg3X1PpVNlfT61fp96vh9Wu4NzD6B05t69dUKP+jMvMwOn/a/ntEHN3rQIVSuSLiJXT+z3lm1WVryvV7On96Hgp8AfheF197d3INOQH4VWb+YxzLVjUDOAz4cmY+D7gf6OpY9DiVztXw+lUmVy/Wryr/HZtcvwCIiJnACuA73XzesUz68s/M24vru4AL6fz5ONztwP7DphcX83Y2v6lcRMQy4BxgZWZurrJsXbky877M3Fbc/iGwR0QsYAK8X4UnbCHV+H5tADZk5lXF9AV0SmS4XqxfZXL1Yv0aM1eP1q9S71ehyfVryCuB32fmxlHuq239mtTlHxFzImLPodvAy4E/jHjYxcDJxbfmRwD3ZuYdwI+Bl0fEvIiYVyz746ZyRcQS4LvAWzLzLxX/TXXmelpERHH7cDrryGbgd8BBEXFAsaXyBjrvbSO5ivv2Al4MXFR12fHIzDuB2yLi4GLWscCfRjys8fWrTK5erF8lczW+fpX879j4+jXMrr5jqG/92t1vqXt5obNnwLXF5Y/A2cX8twNvL24H8CU6exJcB/QNW/5twPricmrDuc4BtgBri0v/rpZtMNc7ivuupfNF4QuGLX888JfivWw0VzH9VuBbZZbtYrblQD+wjs4Qxbxer18lczW+fpXM1fj6VSZXD9evOXQ+/PYaNq+R9cvDO0hSC03qYR9J0vhY/pLUQpa/JLWQ5S9JLWT5S1ILWf7SOEXEtmG3j4+Iv0TEM3qZSSprRq8DSJNdRBwLfB54RT7xwFzShGT5S7uhOM7L/wDHZ+ZNvc4jleWPvKRxiogBYCtwTGau63UeqQrH/KXxGwCupHPUTGlSsfyl8RsETgIOj4gP9TqMVIVj/tJuyMwHIuJVwC8jYmNmfqXXmaQyLH9pN2XmPyLiOOAXEbEpM7tyKGKpTn7hK0kt5Ji/JLWQ5S9JLWT5S1ILWf6S1EKWvyS1kOUvSS1k+UtSC/0/JImDh+CkV0IAAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"results.plot(x='K', y='mean_log_likelihood',kind='scatter')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |