Created
September 26, 2020 07:12
-
-
Save shravankumar147/472e5bc002896e24f47960d5b13580d4 to your computer and use it in GitHub Desktop.
Gradient Descent Intutive understanding
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "Gradient Descent Intutive understanding", | |
| "provenance": [], | |
| "collapsed_sections": [], | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/shravankumar147/472e5bc002896e24f47960d5b13580d4/gradient-descent-intutive-understanding.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "eJbVEriNZqr2" | |
| }, | |
| "source": [ | |
| "**Gradient Descent - Mathematical understanding**" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "-Z8SQ3aOYz9o" | |
| }, | |
| "source": [ | |
| "$f(x) = 2x^2 \\cos(x)$\n", | |
| "\n", | |
| "\n", | |
| "$f^{'}(x) = 4x\\cos(x) - 2x^{2}\\sin(x)$\n", | |
| "\n", | |
| "Parameter update rule: \n", | |
| "\n", | |
| "$x(t+1) = x(t) - \\alpha\\times f^{'}(x(t))$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "Q60spymnnYDF" | |
| }, | |
| "source": [ | |
| "\n", | |
| "\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "aR_BA4UDZ51V" | |
| }, | |
| "source": [ | |
| "# import required packages\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "%matplotlib inline" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "Nkd_lyHjKH5K" | |
| }, | |
| "source": [ | |
| "$f(x) = 2x^2 \\cos(x)$\n", | |
| "\n", | |
| "\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "zrw7zHiAZmta" | |
| }, | |
| "source": [ | |
| "# define helper functions\n", | |
| "\n", | |
| "# Objective / cost function\n", | |
| "def f(x):\n", | |
| "    \"\"\"Cost function f(x) = 2 x^2 cos(x); accepts a scalar or a NumPy array (e.g. np.linspace output).\"\"\"\n", | |
| "    twice_x_squared = 2 * x * x\n", | |
| "    return twice_x_squared * np.cos(x)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "xILmpSfRKiut" | |
| }, | |
| "source": [ | |
| "\n", | |
| "$f^{'}(x) = 4x\\cos(x) - 2x^{2}\\sin(x)$\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "PRV0KhbiaOLI" | |
| }, | |
| "source": [ | |
| "# Derivative of the cost, f'(x) = 4x cos(x) - 2x^2 sin(x)\n", | |
| "def df(x):\n", | |
| "    \"\"\"Analytic derivative of f(x) = 2 x^2 cos(x), obtained with the product rule.\"\"\"\n", | |
| "    cos_term = 4 * x * np.cos(x)\n", | |
| "    sin_term = 2 * x * x * np.sin(x)\n", | |
| "    return cos_term - sin_term" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "aFi7_pc6dlcE" | |
| }, | |
| "source": [ | |
| "# Visualization helper: mark one gradient-descent iterate on the cost curve\n", | |
| "def draw_update(ax, t, x_new):\n", | |
| "    \"\"\"Plot the point (x_new, f(x_new)) on `ax` as a red dot.\n", | |
| "\n", | |
| "    `t` is accepted for call-site compatibility but is not used.\n", | |
| "    \"\"\"\n", | |
| "    y_new = f(x_new)\n", | |
| "    ax.plot(x_new, y_new, 'ro')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "5gu9CDj5MmyZ" | |
| }, | |
| "source": [ | |
| "\n", | |
| "Parameter update rule: \n", | |
| "\n", | |
| "$x(t+1) = x(t) - \\alpha\\times \\left.\\frac{df}{dx}\\right\\vert_{x = x(t)}$\n", | |
| "\n", | |
| "(or)\n", | |
| "\n", | |
| "$x(t+1) = x(t) - \\alpha\\times f^{'}(x(t))$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "YJrEcujCd61f" | |
| }, | |
| "source": [ | |
| "# Gradient-descent parameter update\n", | |
| "def param_update(x_t, alpha, slope):\n", | |
| "    \"\"\"Take one gradient-descent step: x(t+1) = x(t) - alpha * slope.\n", | |
| "\n", | |
| "    x_t   : current parameter value x(t)\n", | |
| "    alpha : learning rate (step size)\n", | |
| "    slope : derivative f'(x) evaluated at x_t\n", | |
| "    \"\"\"\n", | |
| "    step = alpha * slope\n", | |
| "    return x_t - step" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "n2gmwxrxFfNC" | |
| }, | |
| "source": [ | |
| "\n", | |
| "\n", | |
| "> \n", | |
| "\n", | |
| "\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Q8hIPeesFOTL" | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "fZ_l6eZsaA_L" | |
| }, | |
| "source": [ | |
| "# Sample the plotting window [-5, 5] with 50 evenly spaced points (NumPy's default count)\n", | |
| "t = np.linspace(start=-5, stop=5, num=50)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "tdORPf3baJjw", | |
| "outputId": "9c4ec52c-9f01-4385-9662-20b45d019c82", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 265 | |
| } | |
| }, | |
| "source": [ | |
| "# Plot the cost over the sampled window; explicit Axes so the figure is self-describing\n", | |
| "fig, ax = plt.subplots()\n", | |
| "ax.plot(t, f(t))\n", | |
| "ax.set(title='Cost function f(x) = 2 x^2 cos(x)', xlabel='x', ylabel='f(x)')\n", | |
| "ax.grid(True)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [], | |
| "needs_background": "light" | |
| } | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "uq3oSArhmVGv" | |
| }, | |
| "source": [ | |
| "**Local Minimum 1** (start at $x = -5$)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "PoromW_tcIEv", | |
| "outputId": "3b4ed971-9a29-4c03-8b8c-e3902c81c5e7", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 1000 | |
| } | |
| }, | |
| "source": [ | |
| "# Run gradient descent from x0 = -5 and mark every iterate on the cost curve\n", | |
| "x = -5.\n", | |
| "alpha = 0.05\n", | |
| "n_steps = 50\n", | |
| "x_cache, dfx_cache = [], []  # history of iterates and of the slopes that produced them\n", | |
| "best_x, best_fx = x, f(x)  # best point visited so far (start point included)\n", | |
| "fig, ax = plt.subplots()\n", | |
| "t = np.linspace(-5, 5)\n", | |
| "ax.plot(t, f(t))\n", | |
| "ax.set(title='Gradient descent from x0 = -5', xlabel='x', ylabel='f(x)')\n", | |
| "ax.grid(True)\n", | |
| "for i in range(n_steps):\n", | |
| "    slope = df(x)\n", | |
| "    x = param_update(x, alpha, slope)\n", | |
| "    x_cache.append(x)\n", | |
| "    dfx_cache.append(slope)\n", | |
| "    # The update moves x opposite to the slope; a zero slope leaves x unchanged.\n", | |
| "    if slope < 0:\n", | |
| "        print('Go to Right|| -->> || ')\n", | |
| "    elif slope == 0:\n", | |
| "        print('Stay       || ---- ||')\n", | |
| "    else:\n", | |
| "        print('Go to Left || <<-- ||')\n", | |
| "    if not np.isfinite(x):\n", | |
| "        # a too-large alpha can make the iterates overflow to inf/nan; stop instead of plotting them\n", | |
| "        print(f'Diverged after {i + 1} steps (x={x}); try a smaller alpha')\n", | |
| "        break\n", | |
| "    draw_update(ax, t, x)\n", | |
| "    fx = f(x)\n", | |
| "    if fx < best_fx:\n", | |
| "        best_x, best_fx = x, fx\n", | |
| "print(f'Best Value: x={best_x}')\n", | |
| "print(f'Best Function Value: f(x)={best_fx}')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Best Value: x=-3.643597164971067\n", | |
| "Best Funtion Value: f(x)=-23.27565842351116\n" | |
| ], | |
| "name": "stdout" | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [], | |
| "needs_background": "light" | |
| } | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "qf8PFmlhPFIa" | |
| }, | |
| "source": [ | |
| "" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "08omIOLdPEaM" | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "OOnxtRIdPEXq" | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "6SJd_c9rmQvK" | |
| }, | |
| "source": [ | |
| "**Local Minimum 2** (start at $x = -1$)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "kAFKFw_viwIb", | |
| "outputId": "3aa92db5-0b55-4747-bce5-84a6a8bdb7e5", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 1000 | |
| } | |
| }, | |
| "source": [ | |
| "# Run gradient descent from x0 = -1 and mark every iterate on the cost curve\n", | |
| "x = -1.\n", | |
| "alpha = 0.05\n", | |
| "n_steps = 50\n", | |
| "x_cache, dfx_cache = [], []  # history of iterates and of the slopes that produced them\n", | |
| "best_x, best_fx = x, f(x)  # best point visited so far (start point included)\n", | |
| "fig, ax = plt.subplots()\n", | |
| "t = np.linspace(-5, 5)\n", | |
| "ax.plot(t, f(t))\n", | |
| "ax.set(title='Gradient descent from x0 = -1', xlabel='x', ylabel='f(x)')\n", | |
| "ax.grid(True)\n", | |
| "for i in range(n_steps):\n", | |
| "    slope = df(x)\n", | |
| "    x = param_update(x, alpha, slope)\n", | |
| "    x_cache.append(x)\n", | |
| "    dfx_cache.append(slope)\n", | |
| "    # The update moves x opposite to the slope; a zero slope leaves x unchanged.\n", | |
| "    if slope < 0:\n", | |
| "        print('Go to Right|| -->> || ')\n", | |
| "    elif slope == 0:\n", | |
| "        print('Stay       || ---- ||')\n", | |
| "    else:\n", | |
| "        print('Go to Left || <<-- ||')\n", | |
| "    if not np.isfinite(x):\n", | |
| "        # a too-large alpha can make the iterates overflow to inf/nan; stop instead of plotting them\n", | |
| "        print(f'Diverged after {i + 1} steps (x={x}); try a smaller alpha')\n", | |
| "        break\n", | |
| "    draw_update(ax, t, x)\n", | |
| "    fx = f(x)\n", | |
| "    if fx < best_fx:\n", | |
| "        best_x, best_fx = x, fx\n", | |
| "print(f'Best Value: x={best_x}')\n", | |
| "print(f'Best Function Value: f(x)={best_fx}')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Best Value: x=-6.66338443154018e-05\n", | |
| "Best Funtion Value: f(x)=8.880138396784195e-09\n" | |
| ], | |
| "name": "stdout" | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [], | |
| "needs_background": "light" | |
| } | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "AnQ8mzMOnDRv" | |
| }, | |
| "source": [ | |
| "**Local Minimum 3** (start at $x = 2$)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "9mXsahtOmG3N", | |
| "outputId": "62a471ee-6f48-489b-966d-cd99d3ee8296", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 1000 | |
| } | |
| }, | |
| "source": [ | |
| "# Run gradient descent from x0 = 2 and mark every iterate on the cost curve\n", | |
| "x = 2.\n", | |
| "alpha = 0.05\n", | |
| "n_steps = 50\n", | |
| "x_cache, dfx_cache = [], []  # history of iterates and of the slopes that produced them\n", | |
| "best_x, best_fx = x, f(x)  # best point visited so far (start point included)\n", | |
| "fig, ax = plt.subplots()\n", | |
| "t = np.linspace(-5, 5)\n", | |
| "ax.plot(t, f(t))\n", | |
| "ax.set(title='Gradient descent from x0 = 2', xlabel='x', ylabel='f(x)')\n", | |
| "ax.grid(True)\n", | |
| "for i in range(n_steps):\n", | |
| "    slope = df(x)\n", | |
| "    x = param_update(x, alpha, slope)\n", | |
| "    x_cache.append(x)\n", | |
| "    dfx_cache.append(slope)\n", | |
| "    # The update moves x opposite to the slope; a zero slope leaves x unchanged.\n", | |
| "    if slope < 0:\n", | |
| "        print('Go to Right|| -->> || ')\n", | |
| "    elif slope == 0:\n", | |
| "        print('Stay       || ---- ||')\n", | |
| "    else:\n", | |
| "        print('Go to Left || <<-- ||')\n", | |
| "    if not np.isfinite(x):\n", | |
| "        # a too-large alpha can make the iterates overflow to inf/nan; stop instead of plotting them\n", | |
| "        print(f'Diverged after {i + 1} steps (x={x}); try a smaller alpha')\n", | |
| "        break\n", | |
| "    draw_update(ax, t, x)\n", | |
| "    fx = f(x)\n", | |
| "    if fx < best_fx:\n", | |
| "        best_x, best_fx = x, fx\n", | |
| "print(f'Best Value: x={best_x}')\n", | |
| "print(f'Best Function Value: f(x)={best_fx}')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Go to Right|| -->> || \n", | |
| "Go to Left || <<-- ||\n", | |
| "Best Value: x=3.6435971637604454\n", | |
| "Best Funtion Value: f(x)=-23.275658423511157\n" | |
| ], | |
| "name": "stdout" | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [], | |
| "needs_background": "light" | |
| } | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "zMPLt0tGquMo" | |
| }, | |
| "source": [ | |
| "**Final Observations:** \n", | |
| "\n", | |
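| "1. A local minimum is not always the global minimum: the run starting at $x = -1$ stops at $x \\approx 0$ with $f = 0$, while the runs starting at $x = -5$ and $x = 2$ reach $x \\approx \\pm 3.64$ with $f \\approx -23.28$.\n", | |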
| "\n", | |
| "2. Which minimum gradient descent reaches depends on the initial starting point of the parameter.\n", | |
| "\n", | |
| "3. Gradient descent will help you to navigate in the right direction to find the optimal value for the cost function\n", | |
| "\n", | |
| "**Exercises:** \n", | |
| "\n", | |
| "1. Try changing the learning rate (alpha), e.g. to 0.001 and to 1, and observe how gradient descent behaves.\n", | |
| "\n", | |
| "2. Use a random number generator (e.g. `np.random.uniform(-5, 5)`) to pick the initial x value.\n", | |
| "\n", | |
| "3. Try to solve a two-dimensional optimization problem, minimizing $f(x_1, x_2)$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "jDZMSqhEN-lT" | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment