Last active
February 17, 2018 23:50
-
-
Save MInner/ea549d7eb6b53cd8467d8d3b3a5d0d0b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
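"## API showcase\n", | |
"\n", | |
"Examples of annotating function arguments with `NDArr`. `NDArr` is defined in the **Definition** section below, so run that cell before this one." | |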
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 112, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 0.88690007 0.55995198 0.38361163 0.68403184 0.0185156 0.44171983\n", | |
" 0.31691619 0.02937568 0.00836834 0.78880494]\n", | |
"err\n", | |
"[ 0.7695575 0.32766094 0.23568485 0.20659444 0.71498636 0.24506314\n", | |
" 0.304722 0.57193584 0.61933445 0.24320615]\n", | |
"err\n", | |
"[0 0 0 0 0 0 0 0 0 0]\n", | |
"err\n", | |
"[[ 0.82250916 0.40552113 0.8301193 0.53392686 0.95118346 0.23415553\n", | |
" 0.5256737 0.24420565 0.82948549 0.86105571 0.77893775 0.46440709\n", | |
" 0.00869608 0.06291643 0.71569858 0.40135511 0.55258019 0.08750497\n", | |
" 0.1939549 0.40050567]\n", | |
" [ 0.08454198 0.6199718 0.79065577 0.79124625 0.87329852 0.89822278\n", | |
" 0.34901504 0.47744634 0.98617159 0.86818536 0.298728 0.94149995\n", | |
" 0.67553114 0.5602536 0.72955612 0.89932782 0.42081279 0.40602921\n", | |
" 0.89699766 0.83561404]\n", | |
" [ 0.59831343 0.35428523 0.78468076 0.18829544 0.61758936 0.14843397\n", | |
" 0.00596944 0.94308102 0.3325512 0.90827114 0.26069004 0.48808149\n", | |
" 0.63955319 0.5244585 0.57968004 0.46761899 0.47280379 0.74128998\n", | |
" 0.17967916 0.44656813]\n", | |
" [ 0.50523911 0.07422026 0.26768233 0.11226564 0.65410621 0.96435585\n", | |
" 0.05394845 0.14270239 0.5248568 0.78780012 0.90629328 0.27774761\n", | |
" 0.98978354 0.05009231 0.58057635 0.66617815 0.79489857 0.2342683\n", | |
" 0.83614349 0.01591979]\n", | |
" [ 0.04143335 0.61623312 0.64599089 0.41537789 0.23042656 0.43502103\n", | |
" 0.2977826 0.61752906 0.02659585 0.12120256 0.00599285 0.83223846\n", | |
" 0.18163289 0.22369256 0.23993126 0.28848517 0.56280879 0.04957049\n", | |
" 0.88140353 0.25244182]\n", | |
" [ 0.86919285 0.19760094 0.47385413 0.67153636 0.12651067 0.03326239\n", | |
" 0.18196367 0.01481561 0.40641939 0.66161099 0.33279351 0.15837677\n", | |
" 0.39526333 0.29394726 0.61702674 0.96257239 0.17626561 0.8929886\n", | |
" 0.87638873 0.06589202]\n", | |
" [ 0.67906735 0.50630077 0.69801478 0.13452838 0.62129216 0.20966711\n", | |
" 0.65118637 0.93092802 0.06673653 0.58469732 0.51138083 0.81524443\n", | |
" 0.61470865 0.23558816 0.55740143 0.39730801 0.44250091 0.9901409\n", | |
" 0.31236297 0.52510855]\n", | |
" [ 0.07341446 0.76972049 0.97572684 0.77730214 0.33045504 0.18508604\n", | |
" 0.22107951 0.38452896 0.7564924 0.67878656 0.45544665 0.94758416\n", | |
" 0.59257704 0.64189236 0.59865144 0.67881047 0.00933364 0.50467215\n", | |
" 0.63850554 0.3517403 ]\n", | |
" [ 0.73160472 0.75799142 0.67398197 0.37578636 0.88886733 0.63472282\n", | |
" 0.77206326 0.40438063 0.52069013 0.57340567 0.46108637 0.60409608\n", | |
" 0.67011748 0.09200092 0.15267226 0.35911962 0.34958308 0.13727258\n", | |
" 0.76444041 0.16011395]\n", | |
" [ 0.04931594 0.17037391 0.4409902 0.03624572 0.52629626 0.48542883\n", | |
" 0.91848478 0.64623007 0.64216531 0.48021692 0.96896661 0.94137667\n", | |
" 0.34870942 0.71793587 0.79579319 0.20487346 0.47547493 0.90524316\n", | |
" 0.15253653 0.7574594 ]]\n", | |
"err\n", | |
"[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", | |
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n", | |
"err\n", | |
"[[[0 0 0 0]\n", | |
" [0 0 0 0]\n", | |
" [0 0 0 0]]\n", | |
"\n", | |
" [[0 0 0 0]\n", | |
" [0 0 0 0]\n", | |
" [0 0 0 0]]]\n", | |
"err\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"from typecheck import typecheck, InputParameterError\n", | |
"\n", | |
"## NOTE: the NDArr object is defined in the Definition section below --\n", | |
"## run that cell before this one.\n", | |
"##\n", | |
"## Every call marked 'fails!' is expected to be rejected by @typecheck with\n", | |
"## InputParameterError; any other exception is a genuine bug and propagates.\n", | |
"\n", | |
"## you can specify desired shape via [..] syntax\n", | |
"## or by passing shape=(..) value to constructor\n", | |
"@typecheck\n", | |
"def nice_func1(a: NDArr[10]):\n", | |
"    return a\n", | |
"\n", | |
"print(nice_func1(np.random.rand(10)))  ## works!\n", | |
"try:\n", | |
"    print(nice_func1(np.random.rand(20)))  ## fails!\n", | |
"except InputParameterError:\n", | |
"    print('err')\n", | |
"\n", | |
"## does same thing as above\n", | |
"@typecheck\n", | |
"def nice_func1_1(a: NDArr(shape=(10,))):\n", | |
"    return a\n", | |
"\n", | |
"print(nice_func1_1(np.random.rand(10)))  ## works!\n", | |
"try:\n", | |
"    print(nice_func1_1(np.random.rand(20)))  ## fails!\n", | |
"except InputParameterError:\n", | |
"    print('err')\n", | |
"\n", | |
"## you can also specify dtype restrictions\n", | |
"@typecheck\n", | |
"def nice_func2(a: NDArr(dtype=int)):\n", | |
"    return a\n", | |
"\n", | |
"print(nice_func2(np.random.rand(10).astype(int)))  ## works!\n", | |
"try:\n", | |
"    print(nice_func2(np.random.rand(10)))  ## fails!\n", | |
"except InputParameterError:\n", | |
"    print('err')\n", | |
"\n", | |
"## by passing : via [..] syntax you can\n", | |
"## restrict only a subset of dimensions\n", | |
"@typecheck\n", | |
"def nice_func3(a: NDArr[:, 20]):\n", | |
"    return a\n", | |
"\n", | |
"print(nice_func3(np.random.rand(10, 20)))  ## works!\n", | |
"try:\n", | |
"    print(nice_func3(np.random.rand(10, 10)))  ## fails!\n", | |
"except InputParameterError:\n", | |
"    print('err')\n", | |
"\n", | |
"## you can also do both!\n", | |
"## restrictions are rewritten/added on demand\n", | |
"@typecheck\n", | |
"def nice_func4(a: NDArr[:, 20](dtype=int)):\n", | |
"    return a\n", | |
"\n", | |
"print(nice_func4(np.random.rand(10, 20).astype(int)))  ## works!\n", | |
"try:\n", | |
"    print(nice_func4(np.random.rand(10, 20)))  ## fails!\n", | |
"except InputParameterError:\n", | |
"    print('err')\n", | |
"\n", | |
"## or combine in any way\n", | |
"@typecheck\n", | |
"def nice_func5(a: NDArr(ndim=3, dtype=int)):\n", | |
"    return a\n", | |
"\n", | |
"print(nice_func5(np.random.rand(2, 3, 4).astype(int)))  ## works!\n", | |
"try:\n", | |
"    print(nice_func5(np.random.rand(2, 3).astype(int)))  ## fails!\n", | |
"except InputParameterError:\n", | |
"    print('err')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Definition" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 111, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from typecheck import typecheck, InputParameterError\n", | |
"\n", | |
"import abc\n", | |
"from typing import Callable, Any\n", | |
"\n", | |
"class ParameterizedTypeHintPredicate(metaclass=abc.ABCMeta):\n", | |
"    \"\"\"Base class for annotation objects usable as @typecheck predicates.\n", | |
"\n", | |
"    Calling an instance with a value runs `check(value)` and returns its result.\n", | |
"    Calling it with keyword arguments only returns a *new* instance built from\n", | |
"    this instance's attributes overridden by those keywords, so subclasses must\n", | |
"    store every constructor parameter as an attribute of the same name.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    # Sentinel marking 'no value passed', so that an explicit None is checked\n", | |
"    # like any other value instead of being treated as a missing argument.\n", | |
"    _NO_ARGUMENT = object()\n", | |
"\n", | |
"    @abc.abstractmethod\n", | |
"    def check(self, argument):\n", | |
"        \"\"\"Return True if `argument` satisfies this predicate.\"\"\"\n", | |
"        raise NotImplementedError\n", | |
"\n", | |
"    def __call__(self, argument=_NO_ARGUMENT, **kwargs):\n", | |
"        \"\"\"Check `argument`, or return a refined copy built from keyword overrides.\n", | |
"\n", | |
"        If a value is passed, keyword arguments are ignored.\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if neither a value nor keyword arguments are given.\n", | |
"        \"\"\"\n", | |
"        if argument is not self._NO_ARGUMENT:\n", | |
"            return self.check(argument)\n", | |
"        if kwargs:\n", | |
"            # Copy first: updating self.__dict__ directly would mutate this instance.\n", | |
"            param_dict = self.__dict__.copy()\n", | |
"            param_dict.update(kwargs)\n", | |
"            return self.__class__(**param_dict)\n", | |
"        raise ValueError(\"Not enouth params for constructor\")\n", | |
" \n", | |
"class NumpyNdArrayTypeChecker(ParameterizedTypeHintPredicate):\n", | |
"    \"\"\"@typecheck predicate accepting numpy arrays that meet optional constraints.\n", | |
"\n", | |
"    Any constraint left as None is not checked:\n", | |
"      shape            -- compared with == against `ndarray.shape`, e.g. (10,)\n", | |
"      dtype            -- compared with == against `ndarray.dtype`\n", | |
"      ndim             -- exact number of dimensions\n", | |
"      dim_restrictions -- iterable of (axis, size) pairs; that axis must have that length\n", | |
"\n", | |
"    Usage: `NDArr(dtype=int)`, `NDArr[:, 20]`, `NDArr[:, 20](dtype=int)`.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, shape=None, dtype=None, ndim=None, dim_restrictions=None):\n", | |
"        self.shape = shape\n", | |
"        self.dtype = dtype\n", | |
"        self.ndim = ndim\n", | |
"        self.dim_restrictions = dim_restrictions\n", | |
"\n", | |
"    def check(self, argument):\n", | |
"        \"\"\"Return True iff `argument` is an ndarray meeting every constraint that is set.\n", | |
"\n", | |
"        An axis named in `dim_restrictions` that the array does not have counts\n", | |
"        as a mismatch (returns False) rather than raising IndexError.\n", | |
"        \"\"\"\n", | |
"        if not isinstance(argument, np.ndarray):\n", | |
"            return False\n", | |
"        if self.shape is not None and not (self.shape == argument.shape):\n", | |
"            return False\n", | |
"        if self.dtype is not None and not (self.dtype == argument.dtype):\n", | |
"            return False\n", | |
"        if self.ndim is not None and not (self.ndim == argument.ndim):\n", | |
"            return False\n", | |
"        if self.dim_restrictions is not None:\n", | |
"            for axis, size in self.dim_restrictions:\n", | |
"                # Bounds-check first: indexing a missing axis would raise\n", | |
"                # IndexError instead of reporting a failed check.\n", | |
"                if not -argument.ndim <= axis < argument.ndim:\n", | |
"                    return False\n", | |
"                if argument.shape[axis] != size:\n", | |
"                    return False\n", | |
"        return True\n", | |
"\n", | |
"    def __getitem__(self, shape_specs):\n", | |
"        \"\"\"Return a copy constrained by an index expression such as `NDArr[:, 20]`.\n", | |
"\n", | |
"        Each position is one axis: an integer fixes that axis' length, a slice\n", | |
"        (`:`) leaves it free. `ndim` becomes the number of positions, and any\n", | |
"        previous `ndim`/`dim_restrictions` are replaced; other constraints are kept.\n", | |
"\n", | |
"        Example (axes are 0-based):\n", | |
"            shape_specs = [:, :, 3, :, :, 5]\n", | |
"            =>  ndim = 6\n", | |
"                dim_restrictions = [(2, 3), (5, 5)]\n", | |
"                # axis 2 (the third) has length 3, axis 5 (the sixth) length 5\n", | |
"        \"\"\"\n", | |
"        # A single index (`NDArr[10]`, `NDArr[:]`) arrives unwrapped, not as a tuple.\n", | |
"        if isinstance(shape_specs, (int, np.integer, slice)):\n", | |
"            shape_specs = [shape_specs]\n", | |
"\n", | |
"        dim_rst = [(axis, spec) for axis, spec in enumerate(shape_specs)\n", | |
"                   if not isinstance(spec, slice)]\n", | |
"        return self(ndim=len(shape_specs), dim_restrictions=dim_rst)\n", | |
"\n", | |
"\n", | |
"## Base instance used in annotations: `a: NDArr`, `a: NDArr[10]`, `a: NDArr(dtype=int)`\n", | |
"NDArr = NumpyNdArrayTypeChecker()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Tests:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 113, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"All tests have passed!\n" | |
] | |
} | |
], | |
"source": [ | |
"@typecheck\n", | |
"def test_fails(f: callable, err=Exception):\n", | |
"    \"\"\"Assert that calling `f()` raises `err` (any Exception by default).\n", | |
"\n", | |
"    Exceptions that are not `err` propagate unchanged.\n", | |
"\n", | |
"    Raises:\n", | |
"        RuntimeError: if `f()` returns without raising `err`.\n", | |
"    \"\"\"\n", | |
"    try:\n", | |
"        f()\n", | |
"    except err:\n", | |
"        return\n", | |
"    expected = '' if err is Exception else ' ' + str(err)\n", | |
"    raise RuntimeError('Test failed! Try-catch should have raised error' + expected)\n", | |
"\n", | |
"@typecheck\n", | |
"def test_passes(f: callable):\n", | |
"    \"\"\"Assert that calling `f()` raises no exception.\n", | |
"\n", | |
"    Raises:\n", | |
"        RuntimeError: chained from whatever Exception `f()` raised.\n", | |
"    \"\"\"\n", | |
"    try:\n", | |
"        f()\n", | |
"    except Exception as exc:\n", | |
"        raise RuntimeError('Test failed! (must be no errors here)') from exc\n", | |
"\n", | |
"def test():\n", | |
"    \"\"\"Run the NDArr test suite; raises RuntimeError on the first failed expectation.\"\"\"\n", | |
"    ## self-test of the helpers: the inner test_fails only catches\n", | |
"    ## InputParameterError, so ZeroDivisionError must escape it; test_passes\n", | |
"    ## must turn any exception into RuntimeError\n", | |
"    test_fails(lambda: test_fails(lambda: 1/0, err=InputParameterError), err=ZeroDivisionError)\n", | |
"    test_fails(lambda: test_passes(lambda: 1/0), err=RuntimeError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func1(a: np.ndarray) -> np.ndarray:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func1(np.arange(10)))\n", | |
"    test_fails(lambda: func1('str'), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func2(a: NDArr) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func2(np.arange(10)))\n", | |
"    test_fails(lambda: func2('str'), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func3(a: NDArr(shape=(10,))) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func3(np.arange(10)))\n", | |
"    test_fails(lambda: func3('str'), err=InputParameterError)\n", | |
"    test_fails(lambda: func3(np.arange(20)), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func4(a: NDArr(shape=(10,), dtype=np.int32)) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func4(np.arange(10, dtype=np.int32)))\n", | |
"    test_fails(lambda: func4('str'), err=InputParameterError)\n", | |
"    test_fails(lambda: func4(np.arange(20)), err=InputParameterError)\n", | |
"    test_fails(lambda: func4(np.arange(10, dtype=float)), err=InputParameterError)\n", | |
"    ## explicit int64: plain `int` is platform-dependent and maps to int32 on\n", | |
"    ## some platforms, where it would satisfy the int32 restriction\n", | |
"    test_fails(lambda: func4(np.arange(10, dtype=np.int64)), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func5(a: NDArr[10]) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func5(np.arange(10)))\n", | |
"    test_fails(lambda: func5('str'), err=InputParameterError)\n", | |
"    test_fails(lambda: func5(np.arange(20)), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func5_1(a: NDArr[10, 20]) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func5_1(np.random.rand(10, 20)))\n", | |
"    test_fails(lambda: func5_1('str'), err=InputParameterError)\n", | |
"    test_fails(lambda: func5_1(np.arange(20)), err=InputParameterError)\n", | |
"    test_fails(lambda: func5_1(np.random.rand(10, 30)), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func6(a: NDArr[10, :]) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func6(np.random.rand(10, 20)))\n", | |
"    test_passes(lambda: func6(np.random.rand(10, 30)))\n", | |
"    test_fails(lambda: func6('str'), err=InputParameterError)\n", | |
"    test_fails(lambda: func6(np.arange(20)), err=InputParameterError)\n", | |
"    test_fails(lambda: func6(np.arange(10)), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func6_1(a: NDArr[:]) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func6_1(np.random.rand(10)))\n", | |
"    test_passes(lambda: func6_1(np.random.rand(20)))\n", | |
"    test_fails(lambda: func6_1('str'), err=InputParameterError)\n", | |
"    test_fails(lambda: func6_1(np.random.rand(10, 1)), err=InputParameterError)\n", | |
"    test_fails(lambda: func6_1(np.random.rand(10, 20)), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func7(a: NDArr(dtype=int)[10, :]) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func7(np.random.rand(10, 20).astype(dtype=int)))\n", | |
"    test_passes(lambda: func7(np.random.rand(10, 30).astype(dtype=int)))\n", | |
"    test_fails(lambda: func7('str'), err=InputParameterError)\n", | |
"    test_fails(lambda: func7(np.arange(20)), err=InputParameterError)\n", | |
"    test_fails(lambda: func7(np.arange(10)), err=InputParameterError)\n", | |
"    test_fails(lambda: func7(np.random.rand(10, 20)), err=InputParameterError)\n", | |
"    test_fails(lambda: func7(np.random.rand(10, 30)), err=InputParameterError)\n", | |
"\n", | |
"    @typecheck\n", | |
"    def func8(a: NDArr[10, :](dtype=int)) -> NDArr:\n", | |
"        return a\n", | |
"\n", | |
"    test_passes(lambda: func8(np.random.rand(10, 20).astype(dtype=int)))\n", | |
"    test_passes(lambda: func8(np.random.rand(10, 30).astype(dtype=int)))\n", | |
"    test_fails(lambda: func8('str'), err=InputParameterError)\n", | |
"    test_fails(lambda: func8(np.arange(20)), err=InputParameterError)\n", | |
"    test_fails(lambda: func8(np.arange(10)), err=InputParameterError)\n", | |
"    test_fails(lambda: func8(np.random.rand(10, 20)), err=InputParameterError)\n", | |
"    test_fails(lambda: func8(np.random.rand(10, 30)), err=InputParameterError)\n", | |
"\n", | |
"    print('All tests have passed!')\n", | |
"\n", | |
"\n", | |
"test()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment