Skip to content

Instantly share code, notes, and snippets.

@dayyass
Created May 19, 2022 20:07
Show Gist options
  • Save dayyass/d67a513ba8981ffa62014fb12562cc9f to your computer and use it in GitHub Desktop.
Save dayyass/d67a513ba8981ffa62014fb12562cc9f to your computer and use it in GitHub Desktop.
Convert a fitted scikit-learn logistic regression into a PyTorch neural network
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "logreg_sklearn2torch.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "4O4zE-wMqFHq"
},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.pipeline import Pipeline"
]
},
{
"cell_type": "code",
"source": [
"# subset='train' is sklearn's default; spelled out so the data split is explicit\n",
"X, y = fetch_20newsgroups(subset='train', return_X_y=True)"
],
"metadata": {
"id": "7v-MapedrJG9"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### sklearn baseline vs. PyTorch port"
],
"metadata": {
"id": "rcq-E1D65W-9"
}
},
{
"cell_type": "code",
"source": [
"# TF-IDF over terms found in at least 1% of documents, feeding a logistic regression.\n",
"# The step names 'tf-idf' and 'logreg' are looked up by the torch port below.\n",
"pipe = Pipeline(steps=[\n",
"    ('tf-idf', TfidfVectorizer(min_df=0.01)),\n",
"    ('logreg', LogisticRegression()),\n",
"])"
],
"metadata": {
"id": "kwgqAtK1rvXl"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# fit() returns the pipeline itself, hence the repr below\n",
"pipe.fit(X=X, y=y)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oWxNcgWJsyTX",
"outputId": "cd713904-d1d7-4bc6-8acd-d4ede365b978"
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Pipeline(steps=[('tf-idf', TfidfVectorizer(min_df=0.01)),\n",
" ('logreg', LogisticRegression())])"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"source": [
"# reference probabilities from the sklearn pipeline, computed on the CPU\n",
"pred_cpu = pipe.predict_proba(X=X)"
],
"metadata": {
"id": "Et8T_rl20VtV"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class LogReg(torch.nn.Module):\n",
"    \"\"\"PyTorch port of a fitted sklearn TF-IDF + LogisticRegression pipeline.\n",
"\n",
"    TF-IDF featurisation still runs in sklearn on the CPU; only the linear layer\n",
"    and the probability link run in torch, on whatever device the module is on.\n",
"    Binary models (one row in `coef_`) use a sigmoid, as sklearn does with its\n",
"    default `multi_class`; multiclass models use a softmax, which assumes a\n",
"    multinomial fit (sklearn's default for the lbfgs solver).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, pipe):\n",
"        \"\"\"Copy the fitted weights out of `pipe`.\n",
"\n",
"        Args:\n",
"            pipe: fitted sklearn Pipeline with steps named 'tf-idf' and 'logreg'.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.tfidf = pipe['tf-idf']\n",
"        logreg = pipe['logreg']\n",
"\n",
"        n_classes, n_features = logreg.coef_.shape\n",
"        self.linear = torch.nn.Linear(n_features, n_classes)\n",
"\n",
"        with torch.no_grad():\n",
"            self.linear.weight.copy_(torch.as_tensor(logreg.coef_, dtype=torch.float32))\n",
"            self.linear.bias.copy_(torch.as_tensor(logreg.intercept_, dtype=torch.float32))\n",
"\n",
"        # sklearn keeps a single coefficient row for binary problems and applies a\n",
"        # sigmoid to it; a softmax over that one logit would always return 1.0.\n",
"        self.binary = n_classes == 1\n",
"        self.softmax = torch.nn.Softmax(dim=1)\n",
"\n",
"    def forward(self, emb):\n",
"        \"\"\"Map dense TF-IDF rows of shape (batch, n_features) to class probabilities.\n",
"\n",
"        Gradients are tracked here (no `no_grad`), so the module can be fine-tuned;\n",
"        use `predict_proba` for inference.\n",
"        \"\"\"\n",
"        logits = self.linear(emb)\n",
"        if self.binary:\n",
"            p_pos = torch.sigmoid(logits)\n",
"            return torch.cat([1 - p_pos, p_pos], dim=1)\n",
"        return self.softmax(logits)\n",
"\n",
"    def predict_proba(self, texts):\n",
"        \"\"\"Return an (n_texts, n_classes) NumPy array of probabilities, like sklearn.\"\"\"\n",
"        device = self.linear.weight.device\n",
"        # The TF-IDF output is densified here, so memory grows as n_texts * n_features.\n",
"        emb = torch.as_tensor(\n",
"            self.tfidf.transform(texts).toarray(), dtype=torch.float32, device=device,\n",
"        )\n",
"        with torch.no_grad():\n",
"            return self(emb).cpu().numpy()"
],
"metadata": {
"id": "qRta49GpxMcK"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Fall back to the CPU so the notebook still runs on a machine without CUDA.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"model = LogReg(pipe).to(device)"
],
"metadata": {
"id": "oaich7NHv2pB"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# probabilities from the torch model, computed on whichever device `model` lives on\n",
"pred_gpu = model.predict_proba(texts=X)"
],
"metadata": {
"id": "4JDaWZAizvrC"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# np.allclose broadcasts mismatched shapes, so check them explicitly,\n",
"# and assert so a divergence fails loudly instead of just displaying False.\n",
"probas_match = pred_cpu.shape == pred_gpu.shape and np.allclose(pred_cpu, pred_gpu)\n",
"assert probas_match, 'torch probabilities diverge from sklearn'\n",
"probas_match"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WAWqfdze34ie",
"outputId": "c245fda3-b7d3-4686-e1f7-d703983666ee"
},
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"# sklearn end to end: TF-IDF transform + LogisticRegression.predict_proba on the CPU\n",
"pipe.predict_proba(X=X)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Qyn086eh4Rto",
"outputId": "d941df61-7327-4053-878b-5c200b805f2d"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1 loop, best of 5: 1.98 s per loop\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%%timeit\n",
"# torch: the same sklearn TF-IDF transform, then densify, copy to the model's device,\n",
"# linear + softmax, and copy back. Both timings include the TF-IDF step, so this is\n",
"# not a pure classifier-vs-classifier comparison.\n",
"model.predict_proba(texts=X)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2ERIGE2A5mvm",
"outputId": "0baafb91-abd7-46d4-982b-d57dc2bbbe91"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1 loop, best of 5: 2.09 s per loop\n"
]
}
]
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "fs27O7NV5t2R"
},
"execution_count": 11,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment