Last active
October 14, 2021 17:46
-
-
Save fwhigh/c6f9c88cf94cedf2e96d6900ac0f1226 to your computer and use it in GitHub Desktop.
Blog post: lightgbm-vs-keras-metaflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Make a directory you can blow away in full later
mkdir -p aicamp_demo && cd aicamp_demo

# Clone the fwhigh/metaflow-helper git repo and pin it to the v0.0.1 tag
# (canonical GitHub URL rather than a mirror domain)
git clone https://github.com/fwhigh/metaflow-helper.git
cd metaflow-helper
git checkout v0.0.1

# Set up and activate a virtual environment
python -m venv venv && . venv/bin/activate
python -m pip install --upgrade pip

# Install the metaflow-helper package in editable mode, then its example dependencies
python -m pip install -e .
# Assumes Homebrew is available on this machine
brew install lightgbm
python -m pip install -r example-requirements.txt

# Test runs and flow visualization
python examples/model-selection/train.py run --help
python examples/model-selection/train.py run --configuration test_randomized_config
# `dot` (used below to render the flow graph) is provided by graphviz
brew install graphviz
python examples/model-selection/train.py output-dot | dot -Tpng -o model-selection-flow.png

# Full run
python examples/model-selection/train.py run --configuration randomized_config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from importlib import import_module | |
import subprocess | |
def system_command_with_retry(cmd: list, *, max_attempts: int = 5, base_delay: int = 1):
    """Run ``cmd`` until it exits with status 0, backing off exponentially.

    Args:
        cmd: Argument vector handed to ``subprocess.run`` (no shell involved).
        max_attempts: Maximum number of times the command is started.
        base_delay: Seconds waited after the first failure; the wait doubles
            after each further failure.

    Returns:
        None. Failures are only reported on stdout; the function returns
        normally even when every attempt failed.
    """
    # Imported locally so the function works without a module-level `import time`.
    import time

    for attempt in range(max_attempts):
        wait_seconds = base_delay * 2 ** attempt
        try:
            status = subprocess.run(cmd)
        except subprocess.CalledProcessError:
            # `subprocess.run` only raises this when called with check=True;
            # the handler is kept so such a change would still be retried.
            failure = 'command failed'
        else:
            if status.returncode == 0:
                return
            failure = f'command status was {status}'

        if attempt == max_attempts - 1:
            # No point in sleeping when there is no retry left.
            print(f'{failure}, giving up after {max_attempts} attempts')
            return
        print(f'{failure}, retrying after {wait_seconds} seconds')
        time.sleep(wait_seconds)
def install_dependencies(dependencies: list):
    """Install each dependency whose module cannot be imported.

    Args:
        dependencies: List of dicts. Each key is a module name that is tried
            with ``import_module``; the matching value is the requirement
            string passed to ``pip install`` when that import fails with
            ``ModuleNotFoundError``.
    """
    # Imported locally so the function works without a module-level `import sys`.
    import sys

    for dependency in dependencies:
        for module_name, requirement in dependency.items():
            try:
                import_module(module_name)
            except ModuleNotFoundError:
                # Run pip through the current interpreter so the package lands
                # in the environment this process is actually using.
                system_command_with_retry(
                    [sys.executable, '-m', 'pip', 'install', requirement]
                )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "f40ac1dc-7271-45bb-ae44-9112faffdebd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "a4dad677-2266-4bd3-bdd3-a95eeae0e1a6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# metaflow was installed from PyPI using `pip install metaflow`\n", | |
"from metaflow import Metaflow, Flow, Run, Step\n", | |
"\n", | |
"# metaflow_helper is my local package installed using `pip install -e .` at this repo's top level\n", | |
"from metaflow_helper.models import KerasRegressor\n", | |
"from metaflow_helper.feature_engineer import FeatureEngineer\n", | |
"\n", | |
"# This is the common.py script that is local to this flow subdirectory\n", | |
"import common" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "b7bb5881-9af4-4006-985b-55bbb7a5b2e6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[Flow('Train')]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(Metaflow().flows)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "e695fa83-b14e-4823-9fcd-cf5faee1e936", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[Run('Train/1621360258148759'), Run('Train/1621360113161437')]\n", | |
"1621360258148759\n" | |
] | |
} | |
], | |
"source": [ | |
"# These are the class names of your flow specs\n", | |
"flow_map = {'TRAIN': 'Train', 'PREDICT': 'Predict'}\n", | |
"# Pick one to debug in this notebook\n", | |
"debug_key = 'TRAIN'\n", | |
"# What step are you debugging?\n", | |
"debug_step_name = 'train'\n", | |
"# What artifact are you looking at?\n", | |
"debug_artifact_name = ''\n", | |
"flow = Flow(flow_map[debug_key])\n", | |
"print(list(flow))\n", | |
"# What run ID?\n", | |
"run_id = list(flow)[0].id # Get the latest run ID\n", | |
"#run_id = '1621360113161437' # Or fully qualify the run ID if you know it\n", | |
"print(run_id)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "b156008c-d51a-4f2e-a085-827be179d671", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<MetaflowData: test_feature_engineer, train_validation_index, categorical_features, folds, name, df, score, modeler, best_iterations, k_fold, configuration, best_contender, numeric_features, make_regression_init_kwargs, test_index, feature_engineer, contenders, test_modeler, best_contender_ser, contender_results>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"data = list(Step(f'{flow_map[debug_key]}/{run_id}/{debug_step_name}'))[0].data\n", | |
"data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "a73767ef-7a8e-4e28-85ed-7b05c7d62075", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>target</th>\n", | |
" <th>num_0</th>\n", | |
" <th>num_1</th>\n", | |
" <th>num_2</th>\n", | |
" <th>num_3</th>\n", | |
" <th>num_4</th>\n", | |
" <th>num_5</th>\n", | |
" <th>num_6</th>\n", | |
" <th>num_7</th>\n", | |
" <th>num_8</th>\n", | |
" <th>num_9</th>\n", | |
" <th>cat</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>-218.853766</td>\n", | |
" <td>0.396007</td>\n", | |
" <td>-1.093062</td>\n", | |
" <td>0.539249</td>\n", | |
" <td>-0.769916</td>\n", | |
" <td>-0.208299</td>\n", | |
" <td>-0.635846</td>\n", | |
" <td>-0.674333</td>\n", | |
" <td>0.576591</td>\n", | |
" <td>0.676433</td>\n", | |
" <td>0.031831</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>188.815412</td>\n", | |
" <td>1.849591</td>\n", | |
" <td>-0.214167</td>\n", | |
" <td>0.265688</td>\n", | |
" <td>-0.451303</td>\n", | |
" <td>0.019279</td>\n", | |
" <td>0.719984</td>\n", | |
" <td>0.723100</td>\n", | |
" <td>-0.101697</td>\n", | |
" <td>-1.102906</td>\n", | |
" <td>0.024612</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>122.535122</td>\n", | |
" <td>-0.852586</td>\n", | |
" <td>0.022960</td>\n", | |
" <td>0.645055</td>\n", | |
" <td>-0.532490</td>\n", | |
" <td>1.681922</td>\n", | |
" <td>0.468385</td>\n", | |
" <td>1.011842</td>\n", | |
" <td>-0.667713</td>\n", | |
" <td>1.735879</td>\n", | |
" <td>-0.657951</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>-25.764331</td>\n", | |
" <td>-0.370704</td>\n", | |
" <td>-0.945616</td>\n", | |
" <td>-2.320594</td>\n", | |
" <td>0.286904</td>\n", | |
" <td>-1.318396</td>\n", | |
" <td>0.225609</td>\n", | |
" <td>0.317161</td>\n", | |
" <td>-0.067276</td>\n", | |
" <td>0.449712</td>\n", | |
" <td>0.520041</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>-467.003639</td>\n", | |
" <td>0.352817</td>\n", | |
" <td>-0.152774</td>\n", | |
" <td>-1.401347</td>\n", | |
" <td>-0.817493</td>\n", | |
" <td>-0.263937</td>\n", | |
" <td>-1.226622</td>\n", | |
" <td>1.030438</td>\n", | |
" <td>-0.055353</td>\n", | |
" <td>0.967446</td>\n", | |
" <td>-2.047324</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>95</th>\n", | |
" <td>-723.943434</td>\n", | |
" <td>-0.535270</td>\n", | |
" <td>0.025405</td>\n", | |
" <td>-1.153950</td>\n", | |
" <td>-2.582797</td>\n", | |
" <td>-0.406072</td>\n", | |
" <td>-1.032643</td>\n", | |
" <td>-0.347962</td>\n", | |
" <td>-1.642965</td>\n", | |
" <td>-0.436748</td>\n", | |
" <td>-1.353389</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>96</th>\n", | |
" <td>2.167896</td>\n", | |
" <td>1.466579</td>\n", | |
" <td>0.852552</td>\n", | |
" <td>-0.222675</td>\n", | |
" <td>0.567290</td>\n", | |
" <td>1.141102</td>\n", | |
" <td>-0.291837</td>\n", | |
" <td>-0.353432</td>\n", | |
" <td>0.857924</td>\n", | |
" <td>-0.761492</td>\n", | |
" <td>-1.616474</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>97</th>\n", | |
" <td>-386.092904</td>\n", | |
" <td>0.115148</td>\n", | |
" <td>-0.379148</td>\n", | |
" <td>0.417319</td>\n", | |
" <td>-1.550429</td>\n", | |
" <td>-1.660700</td>\n", | |
" <td>-1.405963</td>\n", | |
" <td>-0.944368</td>\n", | |
" <td>-0.110489</td>\n", | |
" <td>-0.590058</td>\n", | |
" <td>0.238103</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>98</th>\n", | |
" <td>-203.314236</td>\n", | |
" <td>-0.387327</td>\n", | |
" <td>-0.302303</td>\n", | |
" <td>0.378163</td>\n", | |
" <td>0.154947</td>\n", | |
" <td>1.202380</td>\n", | |
" <td>-0.347912</td>\n", | |
" <td>-0.887786</td>\n", | |
" <td>1.230291</td>\n", | |
" <td>0.156349</td>\n", | |
" <td>-1.980796</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>99</th>\n", | |
" <td>181.612333</td>\n", | |
" <td>-0.131054</td>\n", | |
" <td>1.133080</td>\n", | |
" <td>-1.829740</td>\n", | |
" <td>0.453782</td>\n", | |
" <td>-1.118312</td>\n", | |
" <td>0.589880</td>\n", | |
" <td>0.037006</td>\n", | |
" <td>-0.805627</td>\n", | |
" <td>-0.363859</td>\n", | |
" <td>0.767902</td>\n", | |
" <td>type_0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>100 rows × 12 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" target num_0 num_1 num_2 num_3 num_4 num_5 \\\n", | |
"0 -218.853766 0.396007 -1.093062 0.539249 -0.769916 -0.208299 -0.635846 \n", | |
"1 188.815412 1.849591 -0.214167 0.265688 -0.451303 0.019279 0.719984 \n", | |
"2 122.535122 -0.852586 0.022960 0.645055 -0.532490 1.681922 0.468385 \n", | |
"3 -25.764331 -0.370704 -0.945616 -2.320594 0.286904 -1.318396 0.225609 \n", | |
"4 -467.003639 0.352817 -0.152774 -1.401347 -0.817493 -0.263937 -1.226622 \n", | |
".. ... ... ... ... ... ... ... \n", | |
"95 -723.943434 -0.535270 0.025405 -1.153950 -2.582797 -0.406072 -1.032643 \n", | |
"96 2.167896 1.466579 0.852552 -0.222675 0.567290 1.141102 -0.291837 \n", | |
"97 -386.092904 0.115148 -0.379148 0.417319 -1.550429 -1.660700 -1.405963 \n", | |
"98 -203.314236 -0.387327 -0.302303 0.378163 0.154947 1.202380 -0.347912 \n", | |
"99 181.612333 -0.131054 1.133080 -1.829740 0.453782 -1.118312 0.589880 \n", | |
"\n", | |
" num_6 num_7 num_8 num_9 cat \n", | |
"0 -0.674333 0.576591 0.676433 0.031831 type_0 \n", | |
"1 0.723100 -0.101697 -1.102906 0.024612 type_0 \n", | |
"2 1.011842 -0.667713 1.735879 -0.657951 type_0 \n", | |
"3 0.317161 -0.067276 0.449712 0.520041 type_0 \n", | |
"4 1.030438 -0.055353 0.967446 -2.047324 type_0 \n", | |
".. ... ... ... ... ... \n", | |
"95 -0.347962 -1.642965 -0.436748 -1.353389 type_0 \n", | |
"96 -0.353432 0.857924 -0.761492 -1.616474 type_0 \n", | |
"97 -0.944368 -0.110489 -0.590058 0.238103 type_0 \n", | |
"98 -0.887786 1.230291 0.156349 -1.980796 type_0 \n", | |
"99 0.037006 -0.805627 -0.363859 0.767902 type_0 \n", | |
"\n", | |
"[100 rows x 12 columns]" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df = data.df\n", | |
"df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "14263569-96bc-47d8-b702-95cf1bd44325", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<metaflow_helper.feature_engineer.FeatureEngineer at 0x15061fed0>" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# This is code from my locally installed, editable package\n", | |
"# Changes made to member functions are immediately available without re-importing\n", | |
"local_feature_engineer = FeatureEngineer(\n", | |
" pipeline_fn=common.build_preprocessor_pipeline,\n", | |
" numeric_features=data.numeric_features,\n", | |
" categorical_features=data.categorical_features,\n", | |
")\n", | |
"local_feature_engineer\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "8436182e-363c-4006-b177-101ba94b278d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<metaflow_helper.feature_engineer.FeatureEngineer at 0x108229f50>" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# This is a deserialized Metaflow artifact\n", | |
"# It reflects the state of the object from the Metaflow run \n", | |
"# I can compare results from my edited local version to this one from the run\n", | |
"metaflow_feature_engineer = data.feature_engineer\n", | |
"metaflow_feature_engineer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "231d4b9e-5dbd-4915-8b50-edec6a118d8f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from scipy.stats import randint, loguniform | |
# Search space for the LightGBM contender.
_lightgbm_search_space = {
    # Algorithm to instantiate
    '__model': ['metaflow_helper.models.LightGBMRegressor'],
    # Forwarded to the model initializer
    '__model__init_kwargs__learning_rate': loguniform(1e-2, 1e-1),
    '__model__init_kwargs__max_depth': randint(1, 4),
    '__model__init_kwargs__n_estimators': [10_000],
    # Forwarded to the model fitter
    '__model__fit_kwargs__eval_metric': ['mse'],
    '__model__fit_kwargs__early_stopping_rounds': [10],
    '__model__fit_kwargs__verbose': [0],
    # Having this key present switches on randomized search
    '__n_iter': 5,
}

# Search space for the Keras contender.
_keras_search_space = {
    # Algorithm to instantiate
    '__model': ['metaflow_helper.models.KerasRegressor'],
    # Forwarded to the model initializer
    '__model__init_kwargs__build_model': ['metaflow_helper.models.build_keras_regression_model'],
    '__model__init_kwargs__metric': ['mse'],
    '__model__init_kwargs__dense_layer_widths': [(), (15,), (15, 15), (15 * 15,)],
    '__model__init_kwargs__l1_factor': loguniform(1e-8, 1e-2),
    '__model__init_kwargs__l2_factor': loguniform(1e-8, 1e-2),
    # Forwarded to the model fitter
    '__model__fit_kwargs__batch_size': [None],
    '__model__fit_kwargs__epochs': [10_000],
    '__model__fit_kwargs__validation_split': [0.2],
    '__model__fit_kwargs__monitor': ['val_mse'],
    '__model__fit_kwargs__verbose': [0],
    '__model__fit_kwargs__patience': [10],
    '__model__fit_kwargs__min_delta': [0.1],
    # Having this key present switches on randomized search
    '__n_iter': 5,
}

# One entry per contender, in the order they are declared above.
contenders_spec = [_lightgbm_search_space, _keras_search_space]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create and activate a dedicated virtual environment
python -m venv metaflow-helper-venv && . metaflow-helper-venv/bin/activate

# Check out a tagged commit (canonical GitHub URL rather than a mirror domain)
git clone https://github.com/fwhigh/metaflow-helper.git
cd metaflow-helper
git checkout v0.0.1

# The package is also available via `pip install metaflow-helper==0.0.1`
python -m pip install --upgrade pip
python -m pip install -e .
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Install the requirements listed in example-requirements.txt
python -m pip install -r example-requirements.txt

# List the options accepted by the flow's `run` command.
# The commented lines below appear to be an excerpt of that help output.
python examples/model-selection/train.py run --help
# --configuration TEXT Which config.py file to use. Available configs:
# randomized_config (default),
# test_randomized_config, grid_config,
# test_grid_config [default: randomized_config]

# Run the flow with the test_randomized_config configuration
python examples/model-selection/train.py run --configuration test_randomized_config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from metaflow import FlowSpec, Parameter, step

import common


class Train(FlowSpec):
    """Flow whose ``start`` step picks data size and training limits by mode.

    The ``test_mode`` parameter selects between a reduced setup (100 rows,
    10 epochs, patience 1) and the full one (all rows, 10_000 epochs,
    patience 50).
    """

    test_mode = Parameter(
        'test_mode',
        help="Run in test mode?",
        type=bool,
        default=False,
    )

    @step
    def start(self):
        """Load the dataframe and set ``max_epochs``/``patience`` for the chosen mode."""
        if self.test_mode:
            # Get a subset of data and reduce parallelism here
            self.df = common.get_dataframe(max_rows=100)
            self.max_epochs = 10
            self.patience = 1
        else:
            self.df = common.get_dataframe()
            # Same attribute name as the test-mode branch so later steps can
            # read `self.max_epochs` regardless of the mode.
            self.max_epochs = 10_000
            self.patience = 50
        # Do stuff here
        self.next(self.end)

    @step
    def end(self):
        """Terminal step; does nothing."""
        pass


if __name__ == '__main__':
    Train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from metaflow import FlowSpec, step | |
import common | |
class Train(FlowSpec):
    """Minimal two-step flow: ``start`` loads data, ``end`` does nothing."""

    @step
    def start(self):
        """Store the result of ``common.get_df()`` on ``self.df`` and go to ``end``."""
        # NOTE(review): another snippet in this gist calls `common.get_dataframe()`;
        # confirm which loader name `common` actually exposes.
        self.df = common.get_df()
        # Do stuff here
        self.next(self.end)

    @step
    def end(self):
        """Terminal step; does nothing."""
        pass


if __name__ == '__main__':
    Train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment