Created March 30, 2017 01:15
Save cicdw/e2de54ca17d615b263f80372031cb865 to your computer and use it in GitHub Desktop.
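Trying to move towards better testing: checks whether each dask_glm solver (admm, newton, gradient_descent, proximal_grad) recovers known coefficients for the Normal and Logistic families.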
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import dask.array as da\n", | |
"import inspect\n", | |
"import numpy as np\n", | |
"\n", | |
"from dask_glm.algorithms import admm, gradient_descent, newton, proximal_grad\n", | |
"from dask_glm.families import Logistic, Normal\n", | |
"from dask_glm.utils import sigmoid" | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## Normal family check\n",
"# Seed NumPy (true_beta) and dask (X) so the errors reported below are reproducible\n",
"np.random.seed(42)\n",
"da.random.seed(42)\n",
"true_beta = (np.random.random(15) - 0.5) * 3\n",
"X = da.random.random((10000, 15), chunks=(2000, 15))\n",
"# Distinct name: the Logistic check defines its own targets, so these stay valid on re-run\n",
"y_normal = X.dot(true_beta)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"for algo in [admm, newton, gradient_descent, proximal_grad]:\n",
"    # Inspect *this* algorithm's signature (previously newton's was checked for every algo),\n",
"    # so each solver that accepts `lamduh` is actually called with lamduh=0\n",
"    sig = inspect.signature(algo)\n",
"    if 'lamduh' in sig.parameters:\n",
"        out = algo(X, y_normal, family=Normal, lamduh=0)\n",
"    else:\n",
"        out = algo(X, y_normal, family=Normal)\n",
"    # Per-coefficient relative error; note it inflates when a true coefficient is near zero\n",
"    perc_diff = (out - true_beta) / true_beta\n",
"    print('{0} error: {1:.9f}'.format(algo.__name__, abs(perc_diff).max()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"## Logistic family check\n",
"# Distinct name so the Normal-family targets are not overwritten if earlier cells are re-run\n",
"y_logistic = sigmoid(X.dot(true_beta))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"for algo in [admm, newton, gradient_descent, proximal_grad]:\n",
"    # Inspect *this* algorithm's signature (previously newton's was checked for every algo),\n",
"    # so each solver that accepts `lamduh` is actually called with lamduh=0\n",
"    sig = inspect.signature(algo)\n",
"    if 'lamduh' in sig.parameters:\n",
"        out = algo(X, y_logistic, family=Logistic, lamduh=0)\n",
"    else:\n",
"        out = algo(X, y_logistic, family=Logistic)\n",
"    # Per-coefficient relative error; note it inflates when a true coefficient is near zero\n",
"    perc_diff = (out - true_beta) / true_beta\n",
"    print('{0} error: {1:.9f}'.format(algo.__name__, abs(perc_diff).max()))"
]
}
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python [default]", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.