Created
June 10, 2020 06:30
-
-
Save ucalyptus/bcf09b711009b87e94f989cb13035909 to your computer and use it in GitHub Desktop.
skearn-random-forest-example.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.6-final" | |
}, | |
"colab": { | |
"name": "skearn-random-forest-example.ipynb", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ucalyptus/bcf09b711009b87e94f989cb13035909/skearn-random-forest-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pIbz1CDs6i0y", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 615 | |
}, | |
"outputId": "7433e265-65c0-4b1a-9f20-ee811cbd8775" | |
}, | |
"source": [ | |
"!pip install scikit-learn==0.21.3\n", | |
"!pip install hummingbird-ml" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting scikit-learn==0.21.3\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a0/c5/d2238762d780dde84a20b8c761f563fe882b88c5a5fb03c056547c442a19/scikit_learn-0.21.3-cp36-cp36m-manylinux1_x86_64.whl (6.7MB)\n", | |
"\u001b[K |████████████████████████████████| 6.7MB 4.4MB/s \n", | |
"\u001b[?25hRequirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn==0.21.3) (1.4.1)\n", | |
"Requirement already satisfied: numpy>=1.11.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn==0.21.3) (1.18.5)\n", | |
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn==0.21.3) (0.15.1)\n", | |
"Installing collected packages: scikit-learn\n", | |
" Found existing installation: scikit-learn 0.22.2.post1\n", | |
" Uninstalling scikit-learn-0.22.2.post1:\n", | |
" Successfully uninstalled scikit-learn-0.22.2.post1\n", | |
"Successfully installed scikit-learn-0.21.3\n", | |
"Collecting hummingbird-ml\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/07/fc/df69fb8cb958826487112832964a9014fa1338dcf17b706bc099701bd524/hummingbird_ml-0.0.1-py2.py3-none-any.whl (49kB)\n", | |
"\u001b[K |████████████████████████████████| 51kB 2.6MB/s \n", | |
"\u001b[?25hRequirement already satisfied: xgboost==0.90 in /usr/local/lib/python3.6/dist-packages (from hummingbird-ml) (0.90)\n", | |
"Requirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.6/dist-packages (from hummingbird-ml) (1.5.0+cu101)\n", | |
"Collecting onnxconverter-common==1.6.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/77/3d/6112c19223d1eabbedf1b063567034e1463a11d7c82d1820f26b75d14e3c/onnxconverter_common-1.6.0-py2.py3-none-any.whl (43kB)\n", | |
"\u001b[K |████████████████████████████████| 51kB 5.7MB/s \n", | |
"\u001b[?25hRequirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from hummingbird-ml) (1.18.5)\n", | |
"Requirement already satisfied: lightgbm>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from hummingbird-ml) (2.2.3)\n", | |
"Requirement already satisfied: scikit-learn==0.21.3 in /usr/local/lib/python3.6/dist-packages (from hummingbird-ml) (0.21.3)\n", | |
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from xgboost==0.90->hummingbird-ml) (1.4.1)\n", | |
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=1.4.0->hummingbird-ml) (0.16.0)\n", | |
"Collecting onnx\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/36/ee/bc7bc88fc8449266add978627e90c363069211584b937fd867b0ccc59f09/onnx-1.7.0-cp36-cp36m-manylinux1_x86_64.whl (7.4MB)\n", | |
"\u001b[K |████████████████████████████████| 7.4MB 12.1MB/s \n", | |
"\u001b[?25hRequirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from onnxconverter-common==1.6.0->hummingbird-ml) (1.12.0)\n", | |
"Requirement already satisfied: protobuf in /usr/local/lib/python3.6/dist-packages (from onnxconverter-common==1.6.0->hummingbird-ml) (3.10.0)\n", | |
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn==0.21.3->hummingbird-ml) (0.15.1)\n", | |
"Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.6/dist-packages (from onnx->onnxconverter-common==1.6.0->hummingbird-ml) (3.6.6)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf->onnxconverter-common==1.6.0->hummingbird-ml) (47.1.1)\n", | |
"Installing collected packages: onnx, onnxconverter-common, hummingbird-ml\n", | |
"Successfully installed hummingbird-ml-0.0.1 onnx-1.7.0 onnxconverter-common-1.6.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wLSQXKMx6Z56", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"import torch\n", | |
"from sklearn.ensemble import RandomForestClassifier\n", | |
"from sklearn.datasets import load_breast_cancer\n", | |
"from hummingbird.ml import convert\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DlnJvE556Z6G", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Load the scikit-learn breast cancer dataset as a (features, labels) pair.\n", | |
"X, y = load_breast_cancer(return_X_y=True)\n", | |
"\n", | |
"# Keep at most `nrows` samples; the slices leave the arrays whole when they hold fewer rows.\n", | |
"nrows = 15000\n", | |
"X = X[:nrows].astype(np.float32)  # 32-bit floats, the same dtype as the '|f4' type string\n", | |
"y = y[:nrows]" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4_IoPJ3d6Z6P", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 136 | |
}, | |
"outputId": "162f2ffc-0fd6-4701-c690-8e1f63de8541" | |
}, | |
"source": [ | |
"# Create and train a random forest model.\n", | |
"model = RandomForestClassifier(n_estimators=10, max_depth=10)\n", | |
"model.fit(X, y)" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", | |
" max_depth=10, max_features='auto', max_leaf_nodes=None,\n", | |
" min_impurity_decrease=0.0, min_impurity_split=None,\n", | |
" min_samples_leaf=1, min_samples_split=2,\n", | |
" min_weight_fraction_leaf=0.0, n_estimators=10,\n", | |
" n_jobs=None, oob_score=False, random_state=None,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lNWLDmGF6Z6W", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "7c0b469e-c89a-4933-8e46-52cf031856b1" | |
}, | |
"source": [ | |
"%%timeit -r 3\n", | |
"\n", | |
"# Time for scikit-learn.\n", | |
"model.predict(X)" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 1.27 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "E5725xaH6Z6f", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Convert the trained scikit-learn forest to a PyTorch model (GEMM tree strategy).\n", | |
"#\n", | |
"# `convert` is a plain module in hummingbird-ml 0.0.1 (the release installed above), so\n", | |
"# calling it raises TypeError: 'module' object is not callable. That release instead\n", | |
"# patches a `.to(backend, ...)` method onto the estimator. Releases that export a\n", | |
"# callable `convert` are handled by the first branch.\n", | |
"extra_config = {\"tree_implementation\": \"gemm\"}\n", | |
"if callable(convert):\n", | |
"    model = convert(model, 'pytorch', extra_config=extra_config)\n", | |
"else:\n", | |
"    model = model.to('pytorch', extra_config=extra_config)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "SM4ih8g26Z6m", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "36dce7f4-af29-4603-a87c-bed26c86564b" | |
}, | |
"source": [ | |
"%%timeit -r 3\n", | |
"\n", | |
"# Time for HB.\n", | |
"model.predict(X)" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 1.25 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3jpof7vo6Z6u", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
}, | |
"outputId": "2676971a-4244-464d-8da3-709e1b1be975" | |
}, | |
"source": [ | |
"model.to('cuda')" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "MissingBackend", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mMissingBackend\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-8-46563f6e5343>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cuda'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/hummingbird/ml/convert.py\u001b[0m in \u001b[0;36m_to_sklearn\u001b[0;34m(self, backend, test_input, extra_config)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mUtility\u001b[0m \u001b[0mfunction\u001b[0m \u001b[0mused\u001b[0m \u001b[0mto\u001b[0m \u001b[0mcall\u001b[0m \u001b[0mthe\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mscikit\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m*\u001b[0m \u001b[0mconverter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \"\"\"\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0m_supported_backend_check\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mconvert_sklearn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextra_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/hummingbird/ml/convert.py\u001b[0m in \u001b[0;36m_supported_backend_check\u001b[0;34m(backend)\u001b[0m\n\u001b[1;32m 29\u001b[0m \"\"\"\n\u001b[1;32m 30\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbackend_map\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mMissingBackend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Backend: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mMissingBackend\u001b[0m: Backend: cuda\nIt usually means the backend is not currently supported.\nPlease check the spelling or fill an issue at https://github.com/microsoft/hummingbird.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "U_PbWOl86Z62", | |
"colab_type": "code", | |
"colab": {}, | |
"outputId": "cbd51fed-7dfe-4e1f-a104-1700665b8d03" | |
}, | |
"source": [ | |
"%%timeit -r 3\n", | |
"\n", | |
"# Time for HB GPU.\n", | |
"model.predict(X)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"467 µs ± 2.29 µs per loop (mean ± std. dev. of 3 runs, 1000 loops each)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HrCBkNN_6Z6_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment