@avidale
Last active May 25, 2023 21:22
Logistic regression initialization.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Logistic regression initialization.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPAUZSyCkrzK+/4xnXOkJpj",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avidale/a640f7a8e353d9efdd79385e277caef1/init_logreg.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNhcEdNxBn14"
},
"source": [
"The problem: logistic regression, one of the simplest and fastest classifiers, still trains relatively slowly, because it is fitted with an iterative optimizer (gradient descent or a related method) rather than in closed form.\r\n",
"\r\n",
"How quickly such an optimizer converges depends on the initialization: if we start from a good initial point, there is a shorter way to go.\r\n",
"\r\n",
"So how can we initialize a logistic regression at a point that is close to the optimal solution and, at the same time, cheap to compute? Our suggestion is that a linear regression trained on the same data is a good candidate. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nDmx2NVpoC8j"
},
"source": [
"Our idea is to initialize the logistic regression $f$ with parameters $\\alpha, \\beta$ chosen so that a linear regression $g$ trained on the same data is tangent to the logistic curve at the center of the dataset $\\bar{X}, \\bar{Y}$ (a linear regression with an intercept always passes through this point, so that $g(\\bar{X})=\\bar{Y}$).\r\n",
"\r\n",
"Linear regression: $$g(X) = a + b X$$\r\n",
"Logistic regression: $$f(X) = \\sigma(\\alpha + \\beta X)$$\r\n",
"where $\\sigma(t) = \\frac{1}{1+e^{-t}}$.\r\n",
"\r\n",
"\"Tangent\" means \"values are equal and first derivatives are equal at this point\", so we want\r\n",
"$$ g(\\bar{X}) = f(\\bar{X})$$\r\n",
"$$ g'(\\bar{X}) = f'(\\bar{X})$$\r\n",
"\r\n",
"Since $g'(X) = b$ and, by the chain rule with $\\sigma'(t) = \\sigma(t)(1-\\sigma(t))$, $f'(X) = f(X)(1-f(X))\\beta$, this is equivalent to\r\n",
"$$ \\bar{Y}= f(\\bar{X})$$\r\n",
"$$ b = f(\\bar{X}) (1-f(\\bar{X}))\\beta$$\r\n",
"\r\n",
"Substituting $f(\\bar{X}) = \\bar{Y}$ into the last equation, we can immediately express the slope coefficient\r\n",
"$$\\beta = \\frac{1}{\\bar{Y}(1-\\bar{Y})}b$$\r\n",
"We can use this value to calculate the intercept \r\n",
"$$ \\alpha = \\sigma^{-1}(\\bar{Y}) - \\beta \\bar{X}$$\r\n",
"\r\n",
"The inverse logistic function $\\sigma^{-1}(t) = \\log\\frac{t}{1-t}$ is called the logit. "
]
},
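{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick numerical sanity check of the formulas above, here is a minimal sketch in plain Python. The helper names `sigmoid` and `tangent_logistic_params` and the numbers `b = 0.2`, `x_mean = 1.5`, `y_mean = 0.3` are made up for illustration and are not part of the original experiment. "
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\r\n",
"\r\n",
"def sigmoid(t):\r\n",
"    # Logistic function sigma(t) = 1 / (1 + exp(-t))\r\n",
"    return 1.0 / (1.0 + math.exp(-t))\r\n",
"\r\n",
"def tangent_logistic_params(b, x_mean, y_mean):\r\n",
"    # Hypothetical helper: parameters of the logistic curve that is tangent\r\n",
"    # to a line with slope b at the point (x_mean, y_mean), 0 < y_mean < 1\r\n",
"    beta = b / (y_mean * (1.0 - y_mean))\r\n",
"    alpha = math.log(y_mean / (1.0 - y_mean)) - beta * x_mean\r\n",
"    return alpha, beta\r\n",
"\r\n",
"# Made-up values, used only to illustrate the tangency conditions\r\n",
"b, x_mean, y_mean = 0.2, 1.5, 0.3\r\n",
"alpha, beta = tangent_logistic_params(b, x_mean, y_mean)\r\n",
"\r\n",
"def f(x):\r\n",
"    return sigmoid(alpha + beta * x)\r\n",
"\r\n",
"# Central finite difference approximates f'(x_mean)\r\n",
"eps = 1e-6\r\n",
"numeric_slope = (f(x_mean + eps) - f(x_mean - eps)) / (2 * eps)\r\n",
"\r\n",
"# Both tangency conditions should hold up to floating-point error\r\n",
"assert math.isclose(f(x_mean), y_mean, rel_tol=1e-9)\r\n",
"assert math.isclose(numeric_slope, b, rel_tol=1e-6)"
],
"execution_count": null,
"outputs": []
},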
{
"cell_type": "markdown",
"metadata": {
"id": "cAUoZjlwr3tp"
},
"source": [
"We start with a small generated dataset. To train the linear regression, we use Ridge, so that we don't have to worry about multicollinearity. \r\n",
"\r\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "88UUviBKkMyq"
},
"source": [
"import numpy as np\r\n",
"import matplotlib.pyplot as plt\r\n",
"\r\n",
"from sklearn import datasets\r\n",
"from sklearn.linear_model import LogisticRegression, Ridge\r\n",
"from sklearn.model_selection import train_test_split\r\n",
"\r\n",
"def logit(x):\r\n",
"    # Inverse of the logistic function: log-odds of a probability x in (0, 1)\r\n",
"    return np.log(x/(1-x))\r\n",
"\r\n",
"\r\n",
"# Create a small classification dataset with a single informative feature\r\n",
"# and no redundant features\r\n",
"X, y = datasets.make_classification(n_samples=100, n_features=1, n_informative=1, n_redundant=0, random_state=42, n_clusters_per_class=1)\r\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"id": "U0Q7KW5Hkg6b",
"outputId": "b9a91372-3deb-4274-b559-52759cd29020"
},
"source": [
"plt.scatter(X[:,0], y, s=3, c='k')\r\n",
"g = np.linspace(X[:,0].min(), X[:,0].max())\r\n",
"linreg = Ridge(alpha=0.01).fit(X, y)\r\n",
"plt.plot(g, linreg.predict(g[:, np.newaxis]))\r\n",
"log_init = LogisticRegression()\r\n",
"ym = y.mean()\r\n",
"log_init.coef_ = linreg.coef_[np.newaxis, :] / (ym * (1-ym))\r\n",
"log_init.intercept_ = logit(ym)[np.newaxis] - np.dot(X.mean(axis=0), log_init.coef_)\r\n",
"log_init.classes_ = np.array([0, 1])\r\n",
"plt.plot(g, log_init.predict_proba(g[:, np.newaxis])[:, 1])\r\n",
"log_final = LogisticRegression().fit(X, y)\r\n",
"plt.plot(g, log_final.predict_proba(g[:, np.newaxis])[:, 1])\r\n",
"\r\n",
"plt.xlabel('X'); plt.ylabel('y'); plt.legend([ 'linear model', 'initial logistic model', 'final logistic model', 'data']);"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vGolOtBPkztp",
"outputId": "33a185db-0e74-4402-bd61-210f19bb07df"
},
"source": [
"l1 = LogisticRegression().fit(X, y)\r\n",
"print(l1.n_iter_, 'iterations with default initialization')\r\n",
"l2 = LogisticRegression(warm_start=True)\r\n",
"ym = y.mean()\r\n",
"l2.coef_ = linreg.coef_[np.newaxis, :] / (ym * (1-ym))\r\n",
"l2.intercept_ = logit(ym)[np.newaxis] - np.dot(X.mean(axis=0), l2.coef_)\r\n",
"l2.classes_ = np.array([0, 1])\r\n",
"l2.fit(X, y)\r\n",
"print(l2.n_iter_, 'iterations with initialization from linear model')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"[9] iterations with default initialization\n",
"[6] iterations with initialization from linear model\n"
],
"name": "stdout"
}
]
},
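{
"cell_type": "markdown",
"metadata": {},
"source": [
"The warm-start setup above is the same for any binary dataset, so it can be wrapped in a helper. This is only a sketch: the name `fit_logreg_with_linear_init` and its parameters are our own, it assumes labels 0/1 with both classes present, and it relies on `LogisticRegression` picking up manually assigned `coef_` and `intercept_` when `warm_start=True`, which may depend on the scikit-learn version and solver. "
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\r\n",
"from sklearn.linear_model import LogisticRegression, Ridge\r\n",
"\r\n",
"def fit_logreg_with_linear_init(X, y, ridge_alpha=0.01, **logreg_kwargs):\r\n",
"    # Hypothetical helper: fit a binary logistic regression that starts\r\n",
"    # from the curve tangent to a Ridge fit at the center of the data\r\n",
"    ym = y.mean()\r\n",
"    ridge = Ridge(alpha=ridge_alpha).fit(X, y)\r\n",
"    model = LogisticRegression(warm_start=True, **logreg_kwargs)\r\n",
"    # beta = b / (Ybar * (1 - Ybar)), stored with shape (1, n_features)\r\n",
"    model.coef_ = ridge.coef_[np.newaxis, :] / (ym * (1 - ym))\r\n",
"    # alpha = logit(Ybar) - beta . Xbar, stored with shape (1,)\r\n",
"    model.intercept_ = np.atleast_1d(np.log(ym / (1 - ym)) - X.mean(axis=0) @ model.coef_[0])\r\n",
"    model.classes_ = np.array([0, 1])\r\n",
"    return model.fit(X, y)\r\n",
"\r\n",
"# Possible usage with the X, y defined earlier:\r\n",
"# print(fit_logreg_with_linear_init(X, y).n_iter_)"
],
"execution_count": null,
"outputs": []
},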
{
"cell_type": "markdown",
"metadata": {
"id": "TQfWnaqSoTsA"
},
"source": [
"On a larger dataset, the smart initialization cuts the number of solver iterations from 292 to 107, almost a threefold reduction. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "vOzDc1_CnTUF"
},
"source": [
"X, y = datasets.make_classification(n_samples=10_000, n_features=1000, n_informative=200, n_redundant=400, random_state=42, n_clusters_per_class=10)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PBV7L29EnbT8",
"outputId": "78cc64ec-019b-409c-80ae-81fc6703eec4"
},
"source": [
"l1 = LogisticRegression(max_iter=1000).fit(X, y)\r\n",
"print(l1.n_iter_, 'iterations with default initialization')\r\n",
"\r\n",
"linreg = Ridge(alpha=0.01).fit(X, y)\r\n",
"l2 = LogisticRegression(warm_start=True, max_iter=1000)\r\n",
"ym = y.mean()\r\n",
"l2.coef_ = linreg.coef_[np.newaxis, :] / (ym * (1-ym))\r\n",
"l2.intercept_ = logit(ym)[np.newaxis] - np.dot(X.mean(axis=0), l2.coef_[0])\r\n",
"l2.classes_ = np.array([0, 1])\r\n",
"l2.fit(X, y)\r\n",
"print(l2.n_iter_, 'iterations with initialization from linear model')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"[292] iterations with default initialization\n",
"[107] iterations with initialization from linear model\n"
],
"name": "stdout"
}
]
},
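{
"cell_type": "markdown",
"metadata": {},
"source": [
"Iteration counts do not include the cost of fitting the linear model, so a wall-clock comparison is a useful complement. The sketch below uses a smaller synthetic dataset and separate variable names (`X_t`, `y_t`, `cold`, `warm`, `ridge_t`) of our own choosing, so that it does not overwrite the variables used in the other cells. The timings depend on the machine and are not results from the original run. "
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import time\r\n",
"import numpy as np\r\n",
"from sklearn import datasets\r\n",
"from sklearn.linear_model import LogisticRegression, Ridge\r\n",
"\r\n",
"# Hypothetical smaller dataset, chosen only to keep the sketch quick to run\r\n",
"X_t, y_t = datasets.make_classification(n_samples=2000, n_features=100, n_informative=20, n_redundant=40, random_state=0)\r\n",
"\r\n",
"# Default initialization\r\n",
"t0 = time.perf_counter()\r\n",
"cold = LogisticRegression(max_iter=1000).fit(X_t, y_t)\r\n",
"t_cold = time.perf_counter() - t0\r\n",
"\r\n",
"# Initialization from the linear model; the Ridge fit is part of the measured time\r\n",
"t0 = time.perf_counter()\r\n",
"ridge_t = Ridge(alpha=0.01).fit(X_t, y_t)\r\n",
"ym_t = y_t.mean()\r\n",
"warm = LogisticRegression(warm_start=True, max_iter=1000)\r\n",
"warm.coef_ = ridge_t.coef_[np.newaxis, :] / (ym_t * (1 - ym_t))\r\n",
"warm.intercept_ = np.atleast_1d(np.log(ym_t / (1 - ym_t)) - X_t.mean(axis=0) @ warm.coef_[0])\r\n",
"warm.classes_ = np.array([0, 1])\r\n",
"warm.fit(X_t, y_t)\r\n",
"t_warm = time.perf_counter() - t0\r\n",
"\r\n",
"print(f'default init: {cold.n_iter_[0]} iterations, {t_cold:.3f} s')\r\n",
"print(f'linear init:  {warm.n_iter_[0]} iterations, {t_warm:.3f} s')"
],
"execution_count": null,
"outputs": []
},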
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"id": "-CR493ReoR4G",
"outputId": "f46fe834-1815-4941-99fc-be22ab37d802"
},
"source": [
"plt.scatter(linreg.coef_, l2.coef_);\r\n",
"plt.xlabel('linear regression coefficients'); plt.ylabel('logistic model coefficients');"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5uz6NkjJodVv"
},
"source": [
"What about real data? On the breast cancer dataset, the coefficients of the linear and logistic regressions seem to have no clean relation to each other, yet the initialization still cuts the number of iterations roughly fourfold (from 2127 to 554). "
]
},
{
"cell_type": "code",
"metadata": {
"id": "5CMhS3jEDpzt"
},
"source": [
"X, y = datasets.load_breast_cancer(return_X_y=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1ZxLCixMsJkO",
"outputId": "4b4350bf-fc68-4888-8b9d-08b00fecdf86"
},
"source": [
"l1 = LogisticRegression(max_iter=10000).fit(X, y)\r\n",
"print(l1.n_iter_, 'iterations with default initialization')\r\n",
"\r\n",
"linreg = Ridge(alpha=0.01).fit(X, y)\r\n",
"l2 = LogisticRegression(warm_start=True, max_iter=10000)\r\n",
"ym = y.mean()\r\n",
"l2.coef_ = linreg.coef_[np.newaxis, :] / (ym * (1-ym))\r\n",
"l2.intercept_ = logit(ym)[np.newaxis] - np.dot(X.mean(axis=0), l2.coef_[0])\r\n",
"l2.classes_ = np.array([0, 1])\r\n",
"l2.fit(X, y)\r\n",
"print(l2.n_iter_, 'iterations with initialization from linear model')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"[2127] iterations with default initialization\n",
"[554] iterations with initialization from linear model\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"id": "QEybmq7fsNih",
"outputId": "92b4964c-5ff0-4415-df61-96cfa1cd3c15"
},
"source": [
"plt.scatter(linreg.coef_, l2.coef_);\r\n",
"plt.xlabel('linear regression coefficients'); plt.ylabel('logistic model coefficients');"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "0SF3WRJlsjGx"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}