@Gijs-Koot
Created October 2, 2019 13:34
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression 101\n",
"\n",
"In this training, you will learn:\n",
"\n",
"* What probabilistic model underlies least squares regression\n",
"* What the likelihood function is\n",
"* How to think about parameter estimation by maximum likelihood, and how to interpret confidence intervals\n",
"* Why your predictions carry more uncertainty than the parameter uncertainty alone, where that extra uncertainty comes from, and how to deal with it\n",
"* How to implement a regression with `tensorflow` and compare it with the output from `statsmodels`\n",
"\n",
"This training should get you started towards:\n",
"\n",
"* Understanding what `tensorflow` does\n",
"* Implementing more complicated models\n",
"* Bayesian statistics\n",
"* Generalized linear models and regression splines"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy import stats"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>x</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>4.085283</td>\n",
" <td>23.957436</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.083008</td>\n",
" <td>8.632136</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3.534593</td>\n",
" <td>20.681548</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3.995216</td>\n",
" <td>22.626877</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2.994028</td>\n",
" <td>18.836193</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1.899187</td>\n",
" <td>14.902008</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1.792251</td>\n",
" <td>10.031126</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>4.042123</td>\n",
" <td>25.267162</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1.676443</td>\n",
" <td>11.839477</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1.353359</td>\n",
" <td>10.657072</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" x y\n",
"0 4.085283 23.957436\n",
"1 1.083008 8.632136\n",
"2 3.534593 20.681548\n",
"3 3.995216 22.626877\n",
"4 2.994028 18.836193\n",
"5 1.899187 14.902008\n",
"6 1.792251 10.031126\n",
"7 4.042123 25.267162\n",
"8 1.676443 11.839477\n",
"9 1.353359 10.657072"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.random.seed(10)\n",
"\n",
"n = 10\n",
"\n",
"x = stats.uniform(1, 4).rvs(n)\n",
"y = 3 + 5 * x + stats.norm().rvs(n) * 2\n",
"\n",
"df = pd.DataFrame({\n",
" \"x\": x, \"y\": y\n",
"})\n",
"\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questions\n",
"\n",
"* Describe a real-life dataset that we could be looking at. What is `x`, what is `y`, and what are the observations?\n",
"* `x` is drawn from a uniform distribution. What does that mean for the scatterplot below?\n",
"* Explain the terms in the formula and what each of them represents in the scatterplot\n",
"* This is synthetic data that follows a linear regression model exactly. Which parts of the formulas guarantee that?\n",
"\n",
"```\n",
"x = stats.uniform(1, 4).rvs(n)\n",
"y = 3 + 5 * x + stats.norm().rvs(n) * 2\n",
"```\n"
]
},
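{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to see which parts of the formula make this a linear regression: simulate the same recipe with many more observations, fit a straight line, and look at the residuals around the true line `3 + 5 * x`. The sketch below assumes a larger sample of `n_large = 10000` observations, a value chosen only for illustration. The fitted intercept and slope should land close to 3 and 5, and the residuals should have a mean close to 0 and a standard deviation close to 2, because the residuals around the true line are exactly the noise term `stats.norm().rvs(n) * 2`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"np.random.seed(0)\n",
"n_large = 10000  # illustrative sample size (an assumption), large enough for stable estimates\n",
"\n",
"x_large = stats.uniform(1, 4).rvs(n_large)  # uniform on [1, 5]: loc = 1, scale = 4\n",
"noise = stats.norm().rvs(n_large) * 2  # normal noise with standard deviation 2\n",
"y_large = 3 + 5 * x_large + noise\n",
"\n",
"# residuals around the true line are the noise term itself\n",
"residuals = y_large - (3 + 5 * x_large)\n",
"print('residual mean:', residuals.mean(), 'residual std:', residuals.std())\n",
"\n",
"# np.polyfit returns the highest-degree coefficient first: [slope, intercept]\n",
"slope, intercept = np.polyfit(x_large, y_large, 1)\n",
"print('intercept:', intercept, 'slope:', slope)\n",
"```"
]
},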
{
"cell_type": "code",
"execution_count": 114,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEKCAYAAAAB0GKPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFihJREFUeJzt3XGQXeV53/Hvs7BeabyaskhrDFpReQzjCfWItbvFtJqmmMREUYlIorSFaWyROKO0DRN76imyk05InU7qKLWTts7Eg4ExTgm2x2sXamMb1bjDMI0xKypkiIihDK5WokgWwkhjabPLffrHHrWr9V3pRey5597d72fmzp7znvfcfV6O0E/nvOfeE5mJJEln09d0AZKk3mBgSJKKGBiSpCIGhiSpiIEhSSpiYEiSihgYkqQiBoYkqYiBIUkqcn7TBSymNWvW5Pr165suQ5J6xu7du3+QmcMlfZdUYKxfv56JiYmmy5CknhER3y/t6yUpSVIRA0OSVMTAkCQVMTAkSUUMDElSEQNDknrEkeNTPLH/ZY4cn2rk9y+p22olaam6b88Bdozvpb+vj+lWi51bN7BldG1Ha/AMQ5K63JHjU+wY38vJ6RbHpmY4Od3i1vG9HT/TqC0wImJdRHwrIvZFxFMR8YGq/Xcj4kBE7KlemxfYf1NE/FVEPBsRH66rTknqdpNHT9Dfd/pf1/19fUwePdHROuq8JDUDfCgzH4+IVcDuiNhVbfujzPz3C+0YEecBfwK8B5gEHouI+zPzL2usV5K60sjQSqZbrdPaplstRoZWdrSO2s4wMvOFzHy8Wj4G7ANKL7hdBTybmc9l5l8DnwNuqKdSSepuqwcH2Ll1Ayv6+1g1cD4r+vvYuXUDqwcHOlpHRya9I2I98A7gUWAjcEtEvA+YYPYs5Oi8XdYC++esTwLvqr9SSepOW0bXsvGyNUwePcHI0MqOhwV0YNI7IgaBceCDmfkK8KfAW4FR4AXg4+12a9OWC7z/9oiYiIiJw4cPL1LVktR9Vg8OcOW6CxoJC6g5MCKin9mwuCczvwSQmS9m5quZ2QI+zezlp/kmgXVz1keAg+1+R2benpljmTk2PFz0Db2SpHNQ511SAdwJ7MvMT8xpv3hOt18Anmyz+2PA5RHxloh4A3AjcH9dtUqSzq7OOYyNwHuB70bEnqrtt4CbImKU2UtMzwO/DhARlwB3ZObmzJyJiFuAbwDnAXdl5lM11ipJOovaAiMzH6H9XMQDC/Q/CGyes/7AQn0lSZ3nJ70lSUUMDElSEQNDklTEwJAkFTEwJKlHdfr5GD4PQ5J6UBPPx/AMQ5J6TFPPxzAwJKnHNPV8DANDknpMU8/HMDAkqcc09XwMJ70lqQc18XwMA0OSetTqwYGOPhvDS1KSpCIGhiSpiIEhSSpiYEiSihgYkqQiBoYkqYiBIUkqYmBIkooYGJLUhTr9rIsStX3SOyLWAZ8F3gy0gNsz8z9ExB8CPwf8NfC/gF/JzJfb7P88cAx4FZjJzLG6apWkbtLEsy5K1HmGMQN8KDN/Arga+I2IuALYBbw9MzcA3wM+cob3eHdmjhoWkpaLpp51UaK2wMjMFzLz8Wr5GLAPWJuZD2bmTNXt28BIXTVIUq9p6lkXJToyhxER64F3AI/O2/SrwNcW2C2BByNid0RsP8N7b4+IiYiYOHz48GKUK0mNaepZFyVqD4yIGATGgQ9m5itz2n+b2ctW9yyw68bMfCfws8xezvrJdp0y8/bMHMvMseHh4UWuXpI6q6lnXZSo9evNI6Kf2bC4JzO/NKd9G3A98FOZme32zcyD1c9DEfFl4Crg4TrrlaRu0MSzLkrUeZdUAHcC+zLzE3PaNwE7gH+QmT9aYN83An2Zeaxavg74aF21SlK36fSzLkrUeUlqI/Be4NqI2FO9NgOfBFYBu6q2TwFExCUR8UC170XAIxHxBPAd4KuZ+fUaa5UknU
VtZxiZ+QgQbTY90Kbt1CWozdXyc8CVddUmSXrt/KS3JKmIgSFJKmJgSJKKGBiSpCIGhiSpiIEhSSpiYEiSihgYkqQiBoYkqYiBIUkqYmBIkooYGJKkIgaGJKmIgSFpSTtyfIon9r/MkeNTTZfS82p94p4kNem+PQfYMb6X/r4+plstdm7dwJbRtU2X1bM8w5C0JB05PsWO8b2cnG5xbGqGk9Mtbh3f65nG62BgSFqSJo+eoL/v9L/i+vv6mDx6oqGKep+BIWlJGhlayXSrdVrbdKvFyNDKhirqfQaGpCVp9eAAO7duYEV/H6sGzmdFfx87t25g9eBA06X1LCe9JS1ZW0bXsvGyNUwePcHI0ErD4nWq7QwjItZFxLciYl9EPBURH6jaL4yIXRHxTPVzaIH9t1V9nomIbXXVKWlpWz04wJXrLjAsFkGdl6RmgA9l5k8AVwO/ERFXAB8GvpmZlwPfrNZPExEXArcB7wKuAm5bKFgkSZ1RW2Bk5guZ+Xi1fAzYB6wFbgDurrrdDfx8m91/BtiVmS9l5lFgF7CprlolSWfXkUnviFgPvAN4FLgoM1+A2VAB3tRml7XA/jnrk1WbJKkhtQdGRAwC48AHM/OV0t3atOUC7789IiYiYuLw4cPnWqYk6SxqDYyI6Gc2LO7JzC9VzS9GxMXV9ouBQ212nQTWzVkfAQ62+x2ZeXtmjmXm2PDw8OIVL0k6TZ13SQVwJ7AvMz8xZ9P9wKm7nrYB97XZ/RvAdRExVE12X1e1SZIaUucZxkbgvcC1EbGnem0GPga8JyKeAd5TrRMRYxFxB0BmvgT8HvBY9fpo1SZJakhktp0a6EljY2M5MTHRdBmS1DMiYndmjpX09atBJElFDAxJUhEDQ5JUxMCQJBUxMCRJRQwMSVIRA0MSR45P8cT+l33etc7IByhJy9x9ew6wY3wv/X19TLda7Ny6gS2jftenfpxnGNIyduT4FDvG93JyusWxqRlOTre4dXyvZxpqy8CQlrHJoyfo7zv9r4H+vj4mj55oqCJ1MwNDWsZGhlYy3Wqd1jbdajEytLKhitTNDAxpGVs9OMDOrRtY0d/HqoHzWdHfx86tG3z+tdpy0lta5raMrmXjZWuYPHqCkaGVhoUWZGBIYvXggEGhs/KSlCSpiIEhSSpiYEiSihgYkqQiBoYkqYiBIUkqYmBIkorU9jmMiLgLuB44lJlvr9o+D7yt6nIB8HJmjrbZ93ngGPAqMJOZY3XVKUkqU+cH9z4DfBL47KmGzPwnp5Yj4uPAD8+w/7sz8we1VSdJek1qC4zMfDgi1rfbFhEB/GPg2rp+vyRpcTU1h/H3gRcz85kFtifwYETsjojtHaxLkrSApr5L6ibg3jNs35iZByPiTcCuiHg6Mx9u17EKlO0Al1566eJXKkkCGjjDiIjzgV8EPr9Qn8w8WP08BHwZuOoMfW/PzLHMHBseHl7sciVJlSYuSf008HRmTrbbGBFvjIhVp5aB64AnO1ifJKmN2gIjIu4F/gJ4W0RMRsT7q003Mu9yVERcEhEPVKsXAY9ExBPAd4CvZubX66pTklTmrHMYEXELcE9mHn0tb5yZNy3QfnObtoPA5mr5OeDK1/K7JEn1KznDeDPwWER8ISI2VbfESpKWmbMGRmb+a+By4E7gZuCZiPj9iHhrzbVJkrpI0RxGZibwf6rXDDAEfDEidtZYmySpi5TMYfwmsA34AXAH8K8yczoi+oBngFvrLVGS1A1KPri3BvjFzPz+3MbMbEXE9fWUJUnqNmcNjMz8nTNs27e45UiSupXPw5AkFTEwJElFDAxJUhEDQ5JUxMCQJBUxMCRJRQwMSVIRA0OSVMTAkCQVMTAkSUUMDElSEQNDklTEwJAkFTEwJElFDAxJUpHaAiMi7oqIQxHx5Jy2342IAxGxp3ptXmDfTRHxVxHxbER8uK4aJUnl6jzD+AywqU37H2XmaPV6YP7GiDgP+BPgZ4ErgJsi4ooa69QiOnJ8ii
f2v8yR41NNlyJpkZU8ovWcZObDEbH+HHa9Cng2M58DiIjPATcAf7l41akO9+05wI7xvfT39THdarFz6wa2jK5tuixJi6SJOYxbImJvdclqqM32tcD+OeuTVVtbEbE9IiYiYuLw4cOLXasKHTk+xY7xvZycbnFsaoaT0y1uHd/rmYa0hHQ6MP4UeCswCrwAfLxNn2jTlgu9YWbenpljmTk2PDy8OFXqNZs8eoL+vtP/OPX39TF59ERDFUlabB0NjMx8MTNfzcwW8GlmLz/NNwmsm7M+AhzsRH06dyNDK5lutU5rm261GBla2VBFkhZbRwMjIi6es/oLwJNtuj0GXB4Rb4mINwA3Avd3oj6du9WDA+zcuoEV/X2sGjifFf197Ny6gdWDA02XJmmR1DbpHRH3AtcAayJiErgNuCYiRpm9xPQ88OtV30uAOzJzc2bORMQtwDeA84C7MvOpuurU4tkyupaNl61h8ugJRoZWGhbSEhOZC04P9JyxsbGcmJhougxJ6hkRsTszx0r6+klvSVIRA0OSVMTAkCQVMTAkSUUMDElSEQNDklTEwJAkFTEwJElFDAxJUhEDQ5JUxMCQJBUxMCRJRQwMSVIRA0OSVMTAkCQVMTAkSUUMDElSEQNDklTEwJAkFTEwJElFaguMiLgrIg5FxJNz2v4wIp6OiL0R8eWIuGCBfZ+PiO9GxJ6ImKirRklSuTrPMD4DbJrXtgt4e2ZuAL4HfOQM+787M0czc6ym+iRJr0FtgZGZDwMvzWt7MDNnqtVvAyN1/X5J0uJqcg7jV4GvLbAtgQcjYndEbO9gTZKkBZzfxC+NiN8GZoB7FuiyMTMPRsSbgF0R8XR1xtLuvbYD2wEuvfTSWuqVJDVwhhER24DrgX+amdmuT2YerH4eAr4MXLXQ+2Xm7Zk5lpljw8PDdZQsSaLDgRERm4AdwJbM/NECfd4YEatOLQPXAU+26ytJ6pw6b6u9F/gL4G0RMRkR7wc+Caxi9jLTnoj4VNX3koh4oNr1IuCRiHgC+A7w1cz8el11SpLK1DaHkZk3tWm+c4G+B4HN1fJzwJV11SVJOjd+0nsZOnJ8iif2v8yR41NNlyKphzRyl5Sac9+eA+wY30t/Xx/TrRY7t25gy+japsuS1AM8w1hGjhyfYsf4Xk5Otzg2NcPJ6Ra3ju/1TENSEQNjGZk8eoL+vtMPeX9fH5NHTzRUkaReYmAsIyNDK5lutU5rm261GBla2VBFknqJgbGMrB4cYOfWDazo72PVwPms6O9j59YNrB4caLo0ST3ASe9lZsvoWjZetobJoycYGVppWEgqZmAsQ6sHBwwKSa+Zl6QkSUUMDElSEQNDklTEwJAkFTEwJElFDAxJUhEDowF+W6ykXuTnMDrMb4uV1Ks8w+ggvy1WUi8zMDrIb4uV1MsMjA7y22Il9TIDo4P8tlhJvcxJ7w7z22Il9apazzAi4q6IOBQRT85puzAidkXEM9XPoQX23Vb1eSYittVZZ6etHhzgynUX9ExYeBuwJKj/ktRngE3z2j4MfDMzLwe+Wa2fJiIuBG4D3gVcBdy2ULCoXvftOcDGP3iIX77jUTb+wUPcv+dA0yVJakitgZGZDwMvzWu+Abi7Wr4b+Pk2u/4MsCszX8rMo8Aufjx4VDNvA5Y0VxOT3hdl5gsA1c83temzFtg/Z32yalMHeRuwpLm69S6paNOWbTtGbI+IiYiYOHz4cM1lLS/eBixpriYC48WIuBig+nmoTZ9JYN2c9RHgYLs3y8zbM3MsM8eGh4cXvdjlzNuAJc3VxG219wPbgI9VP+9r0+cbwO/Pmei+DvhIZ8rTXN4GLOmUWgMjIu4FrgHWRMQks3c+fQz4QkS8H/jfwD+q+o4B/ywzfy0zX4qI3wMeq97qo5k5f/JcHbJ6cMCgkERktp0a6EljY2M5MTHRdBmS1DMiYndmjpX07dZJb0lSlzEwJElFDAxJUhEDQ5JUxMCQJBUxMCRJRQwMSVIRA0
OSVMTAwAcESVKJZf+I1vv2HGDH+F76+/qYbrXYuXUDW0b9JnVJmm9Zn2H4gCBJKresA8MHBElSuWUdGD4gSJLKLevA8AFBklRu2U96+4AgSSqz7AMDfECQJJVY1pekJEnlDAxJUhEDQ5JUxMCQJBUxMCRJRSIzm65h0UTEYeD7C2xeA/ygg+XUaamMxXF0n6UylqUyDqh/LH8zM4dLOi6pwDiTiJjIzLGm61gMS2UsjqP7LJWxLJVxQHeNxUtSkqQiBoYkqchyCozbmy5gES2VsTiO7rNUxrJUxgFdNJZlM4chSXp9ltMZhiTpdVhSgRERd0XEoYh4coHtERH/MSKejYi9EfHOTtdYqmAs10TEDyNiT/X6nU7XWCIi1kXEtyJiX0Q8FREfaNOn649L4Th65ZisiIjvRMQT1Vj+TZs+AxHx+eqYPBoR6ztf6ZkVjuPmiDg855j8WhO1loiI8yLif0bEV9ps647jkZlL5gX8JPBO4MkFtm8GvgYEcDXwaNM1v46xXAN8pek6C8ZxMfDOankV8D3gil47LoXj6JVjEsBgtdwPPApcPa/PvwA+VS3fCHy+6brPcRw3A59sutbC8fxL4M/b/RnqluOxpM4wMvNh4KUzdLkB+GzO+jZwQURc3JnqXpuCsfSEzHwhMx+vlo8B+4C187p1/XEpHEdPqP47H69W+6vX/MnMG4C7q+UvAj8VEdGhEosUjqMnRMQI8A+BOxbo0hXHY0kFRoG1wP4565P06P/0lb9bnY5/LSL+VtPFnE11Gv0OZv8lOFdPHZczjAN65JhUlz/2AIeAXZm54DHJzBngh8DqzlZ5dgXjANhaXer8YkSs63CJpf4YuBVoLbC9K47HcguMdonck/8iAR5n9iP9VwL/CfgvDddzRhExCIwDH8zMV+ZvbrNLVx6Xs4yjZ45JZr6amaPACHBVRLx9XpeeOCYF4/ivwPrM3AD8N/7/v9K7RkRcDxzKzN1n6tamrePHY7kFxiQw918YI8DBhmp5XTLzlVOn45n5ANAfEWsaLqutiOhn9i/ZezLzS2269MRxOds4eumYnJKZLwP/Hdg0b9P/OyYRcT7wN+jiS6QLjSMzj2TmVLX6aeBvd7i0EhuBLRHxPPA54NqI+M/z+nTF8VhugXE/8L7qrpyrgR9m5gtNF3UuIuLNp65hRsRVzB7LI81W9eOqGu8E9mXmJxbo1vXHpWQcPXRMhiPigmp5JfDTwNPzut0PbKuWfwl4KKsZ125RMo55c2FbmJ176iqZ+ZHMHMnM9cxOaD+Umb88r1tXHI8l9UzviLiX2TtV1kTEJHAbsxNhZOangAeYvSPnWeBHwK80U+nZFYzll4B/HhEzwAngxm77H7qyEXgv8N3qWjPAbwGXQk8dl5Jx9MoxuRi4OyLOYzbUvpCZX4mIjwITmXk/s+H4ZxHxLLP/kr2xuXIXVDKO34yILcAMs+O4ubFqX6NuPB5+0luSVGS5XZKSJJ0jA0OSVMTAkCQVMTAkSUUMDElSEQNDklTEwJAkFTEwpJpExN+pvvRuRUS8sXpmw/zvOpJ6hh/ck2oUEf8WWAGsBCYz8981XJJ0zgwMqUYR8QbgMeAk8Pcy89WGS5LOmZekpHpdCAwy+5S+FQ3XIr0unmFINYqI+5n9yuq3ABdn5i0NlySdsyX1bbVSN4mI9wEzmfnn1Teq/o+IuDYzH2q6NulceIYhSSriHIYkqYiBIUkqYmBIkooYGJKkIgaGJKmIgSFJKmJgSJKKGBiSpCL/F1jFpb12pIDQAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df.plot.scatter(\"x\", \"y\");"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax = df.plot.scatter(\"x\", \"y\", label = \"observations\")\n",
"ax.plot(x, 3 + 5 * x, color = \"red\", label = \"true model\")\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\gijs.koot\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\scipy\\stats\\stats.py:1394: UserWarning: kurtosistest only valid for n>=20 ... continuing anyway, n=10\n",
" \"anyway, n=%i\" % int(n))\n"
]
},
{
"data": {
"text/html": [
"<table class=\"simpletable\">\n",
"<caption>OLS Regression Results</caption>\n",
"<tr>\n",
" <th>Dep. Variable:</th> <td>y</td> <th> R-squared: </th> <td> 0.963</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Model:</th> <td>OLS</td> <th> Adj. R-squared: </th> <td> 0.958</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Method:</th> <td>Least Squares</td> <th> F-statistic: </th> <td> 207.8</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Date:</th> <td>Thu, 20 Sep 2018</td> <th> Prob (F-statistic):</th> <td>5.24e-07</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Time:</th> <td>17:00:46</td> <th> Log-Likelihood: </th> <td> -15.556</td>\n",
"</tr>\n",
"<tr>\n",
" <th>No. Observations:</th> <td> 10</td> <th> AIC: </th> <td> 35.11</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Df Residuals:</th> <td> 8</td> <th> BIC: </th> <td> 35.72</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Df Model:</th> <td> 1</td> <th> </th> <td> </td> \n",
"</tr>\n",
"<tr>\n",
" <th>Covariance Type:</th> <td>nonrobust</td> <th> </th> <td> </td> \n",
"</tr>\n",
"</table>\n",
"<table class=\"simpletable\">\n",
"<tr>\n",
" <td></td> <th>coef</th> <th>std err</th> <th>t</th> <th>P>|t|</th> <th>[0.025</th> <th>0.975]</th> \n",
"</tr>\n",
"<tr>\n",
" <th>Intercept</th> <td> 3.2338</td> <td> 1.021</td> <td> 3.167</td> <td> 0.013</td> <td> 0.879</td> <td> 5.588</td>\n",
"</tr>\n",
"<tr>\n",
" <th>x</th> <td> 5.1064</td> <td> 0.354</td> <td> 14.415</td> <td> 0.000</td> <td> 4.290</td> <td> 5.923</td>\n",
"</tr>\n",
"</table>\n",
"<table class=\"simpletable\">\n",
"<tr>\n",
" <th>Omnibus:</th> <td> 0.851</td> <th> Durbin-Watson: </th> <td> 3.017</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Prob(Omnibus):</th> <td> 0.653</td> <th> Jarque-Bera (JB): </th> <td> 0.107</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Skew:</th> <td>-0.252</td> <th> Prob(JB): </th> <td> 0.948</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Kurtosis:</th> <td> 2.942</td> <th> Cond. No. </th> <td> 8.01</td>\n",
"</tr>\n",
"</table><br/><br/>Warnings:<br/>[1] Standard Errors assume that the covariance matrix of the errors is correctly specified."
],
"text/plain": [
"<class 'statsmodels.iolib.summary.Summary'>\n",
"\"\"\"\n",
" OLS Regression Results \n",
"==============================================================================\n",
"Dep. Variable: y R-squared: 0.963\n",
"Model: OLS Adj. R-squared: 0.958\n",
"Method: Least Squares F-statistic: 207.8\n",
"Date: Thu, 20 Sep 2018 Prob (F-statistic): 5.24e-07\n",
"Time: 17:00:46 Log-Likelihood: -15.556\n",
"No. Observations: 10 AIC: 35.11\n",
"Df Residuals: 8 BIC: 35.72\n",
"Df Model: 1 \n",
"Covariance Type: nonrobust \n",
"==============================================================================\n",
" coef std err t P>|t| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"Intercept 3.2338 1.021 3.167 0.013 0.879 5.588\n",
"x 5.1064 0.354 14.415 0.000 4.290 5.923\n",
"==============================================================================\n",
"Omnibus: 0.851 Durbin-Watson: 3.017\n",
"Prob(Omnibus): 0.653 Jarque-Bera (JB): 0.107\n",
"Skew: -0.252 Prob(JB): 0.948\n",
"Kurtosis: 2.942 Cond. No. 8.01\n",
"==============================================================================\n",
"\n",
"Warnings:\n",
"[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n",
"\"\"\""
]
},
"execution_count": 116,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import statsmodels.formula.api as smf\n",
"\n",
"smf.ols('y ~ x', data = df).fit().summary()"
]
},
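{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `coef` and `std err` columns of the summary above can be reproduced with a little linear algebra. This sketch regenerates the same `np.random.seed(10)` data, solves the normal equations $(X^T X)\\beta = X^T y$ for $\\beta = (a, b)$, and takes the standard errors as the square roots of the diagonal of $s^2 (X^T X)^{-1}$ with $s^2 = \\mathrm{RSS}/(n - 2)$. Assuming `scipy.stats` draws the same random numbers as when the notebook was run, the printed values should match the `Intercept` and `x` rows of the table up to rounding.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"# regenerate the data from the top of the notebook (same seed and recipe)\n",
"np.random.seed(10)\n",
"n = 10\n",
"x = stats.uniform(1, 4).rvs(n)\n",
"y = 3 + 5 * x + stats.norm().rvs(n) * 2\n",
"\n",
"# design matrix: a column of ones for the intercept a, and x for the slope b\n",
"X = np.column_stack([np.ones(n), x])\n",
"\n",
"# least squares estimate from the normal equations (X^T X) beta = X^T y\n",
"beta = np.linalg.solve(X.T @ X, X.T @ y)\n",
"\n",
"# residual variance with n - 2 degrees of freedom (two estimated parameters)\n",
"rss = np.sum((y - X @ beta) ** 2)\n",
"s2 = rss / (n - 2)\n",
"\n",
"# standard errors: square roots of the diagonal of s^2 (X^T X)^{-1}\n",
"std_err = np.sqrt(np.diag(s2 * np.linalg.inv(X.T @ X)))\n",
"\n",
"print('coef:', beta)\n",
"print('std err:', std_err)\n",
"```\n",
"\n",
"The `t` column of the summary is then simply `beta / std_err`."
]
},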
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[ 3, 4, 5, 6, 7]],\n",
"\n",
" [[ 6, 8, 10, 12, 14]],\n",
"\n",
" [[ 9, 12, 15, 18, 21]]])"
]
},
"execution_count": 117,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = np.array([1, 2, 3])\n",
"b = np.array([3, 4, 5, 6, 7])\n",
"\n",
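"# (3, 1, 1) @ (1, 1, 5): a batch of three (1, 1) @ (1, 5) matrix products, so the result\n",
"# has shape (3, 1, 5) and contains every value of a multiplied by every value of b\n",
"a.reshape(-1, 1, 1) @ b.reshape(1, 1, -1)"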
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" <th>b</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>2</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>3</td>\n",
" <td>7</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b\n",
"0 1 3\n",
"1 1 4\n",
"2 1 5\n",
"3 1 6\n",
"4 1 7\n",
"5 2 3\n",
"6 2 4\n",
"7 2 5\n",
"8 2 6\n",
"9 2 7\n",
"10 3 3\n",
"11 3 4\n",
"12 3 5\n",
"13 3 6\n",
"14 3 7"
]
},
"execution_count": 118,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# one row per combination of a value from a with a value from b\n",
"parameters = pd.DataFrame(np.array(np.meshgrid(a, b)).T.reshape(-1, 2), columns = [\"a\", \"b\"])\n",
"parameters"
]
},
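{
"cell_type": "markdown",
"metadata": {},
"source": [
"The expression `np.array(np.meshgrid(a, b)).T.reshape(-1, 2)` enumerates every combination of a value from `a` with a value from `b`, with `a` varying slowest. A small sketch comparing it with `itertools.product`, which produces the same pairs in the same order:\n",
"\n",
"```python\n",
"import itertools\n",
"import numpy as np\n",
"\n",
"a = np.array([1, 2, 3])\n",
"b = np.array([3, 4, 5, 6, 7])\n",
"\n",
"# meshgrid builds two (len(b), len(a)) grids; transposing and reshaping gives one row per (a, b) pair\n",
"pairs_meshgrid = np.array(np.meshgrid(a, b)).T.reshape(-1, 2)\n",
"\n",
"# the same Cartesian product written out explicitly\n",
"pairs_product = np.array(list(itertools.product(a, b)))\n",
"\n",
"print(np.array_equal(pairs_meshgrid, pairs_product))\n",
"```"
]
},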
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>a</th>\n",
" <th>b</th>\n",
" <th>squared_error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>680.640539</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>295.519244</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>76.561176</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>6</td>\n",
" <td>23.766336</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>7</td>\n",
" <td>137.134723</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>534.511413</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>202.301099</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>2</td>\n",
" <td>5</td>\n",
" <td>36.254012</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>2</td>\n",
" <td>6</td>\n",
" <td>36.370153</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>2</td>\n",
" <td>7</td>\n",
" <td>202.649520</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>408.382287</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>3</td>\n",
" <td>4</td>\n",
" <td>129.082954</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>15.946848</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>3</td>\n",
" <td>6</td>\n",
" <td>68.973969</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>3</td>\n",
" <td>7</td>\n",
" <td>288.164318</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" a b squared_error\n",
"0 1 3 680.640539\n",
"1 1 4 295.519244\n",
"2 1 5 76.561176\n",
"3 1 6 23.766336\n",
"4 1 7 137.134723\n",
"5 2 3 534.511413\n",
"6 2 4 202.301099\n",
"7 2 5 36.254012\n",
"8 2 6 36.370153\n",
"9 2 7 202.649520\n",
"10 3 3 408.382287\n",
"11 3 4 129.082954\n",
"12 3 5 15.946848\n",
"13 3 6 68.973969\n",
"14 3 7 288.164318"
]
},
"execution_count": 119,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
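"# residuals has shape (n, number of (a, b) pairs): rows are observations, columns are parameter pairs\n",
"residuals = parameters.a.values + parameters.b.values.reshape(1, -1) * x.reshape(-1, 1) - y.reshape(-1, 1)\n",
"# sum the squared residuals over the observations, giving one value per (a, b) pair\n",
"parameters[\"squared_error\"] = (residuals ** 2).sum(axis = 0)\n",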
"parameters"
]
},
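{
"cell_type": "markdown",
"metadata": {},
"source": [
"The one-line `squared_error` computation relies on broadcasting: `x.reshape(-1, 1)` has shape `(n, 1)` and the parameter columns have shape `(1, k)` (or `(k,)`, which broadcasts the same way), so the residuals form an `(n, k)` matrix with one column per parameter pair, and `.sum(axis = 0)` adds up the squared residuals in each column. The sketch below, assuming the seed-10 data and the small grid from above, checks this against an explicit double loop.\n",
"\n",
"```python\n",
"import itertools\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"# regenerate the data from the top of the notebook (same seed and recipe)\n",
"np.random.seed(10)\n",
"n = 10\n",
"x = stats.uniform(1, 4).rvs(n)\n",
"y = 3 + 5 * x + stats.norm().rvs(n) * 2\n",
"\n",
"# all (a, b) pairs from the small grid, one row per pair\n",
"grid = np.array(list(itertools.product([1, 2, 3], [3, 4, 5, 6, 7])), dtype=float)\n",
"a_values, b_values = grid[:, 0], grid[:, 1]\n",
"\n",
"# vectorized: residual matrix of shape (n, k), summed over the observations (axis 0)\n",
"residuals = a_values + b_values.reshape(1, -1) * x.reshape(-1, 1) - y.reshape(-1, 1)\n",
"sse_vectorized = (residuals ** 2).sum(axis=0)\n",
"\n",
"# explicit loop over parameter pairs and observations\n",
"sse_loop = np.array([\n",
"    sum((a + b * xi - yi) ** 2 for xi, yi in zip(x, y))\n",
"    for a, b in grid\n",
"])\n",
"\n",
"print(np.allclose(sse_vectorized, sse_loop))\n",
"```"
]
},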
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x1a91c7ac0f0>"
]
},
"execution_count": 120,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib\n",
"\n",
"parameters.plot.scatter(\"a\", \"b\", c = \"squared_error\", cmap = matplotlib.cm.Reds)"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# a finer grid: a and b both run from 0 to 9.9 in steps of 0.1\n",
"a = np.arange(0, 10, .1)\n",
"b = np.arange(0, 10, .1)\n",
"\n",
"parameters = pd.DataFrame(np.array(np.meshgrid(a, b)).T.reshape(-1, 2), columns = [\"a\", \"b\"])\n",
"residuals = parameters.a.values + parameters.b.values.reshape(1, -1) * x.reshape(-1, 1) - y.reshape(-1, 1)\n",
"parameters[\"squared_error\"] = (residuals ** 2).sum(axis = 0)\n",
"\n",
"ax = parameters.plot.scatter(\"a\", \"b\", c = \"squared_error\", cmap = matplotlib.cm.Reds)\n",
"# mark the true parameters a = 3, b = 5\n",
"ax.plot([3], [5], \"+\");"
]
},
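{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grid search is a brute-force version of least squares. A sketch, assuming the same seed-10 data and the `0.1`-spaced grid above: the best grid point should lie close to the exact least squares solution from `np.linalg.lstsq`, and its squared error can never be lower than the exact minimum.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"# regenerate the data from the top of the notebook (same seed and recipe)\n",
"np.random.seed(10)\n",
"n = 10\n",
"x = stats.uniform(1, 4).rvs(n)\n",
"y = 3 + 5 * x + stats.norm().rvs(n) * 2\n",
"\n",
"# same grid as above: every (a, b) with a, b in 0, 0.1, ..., 9.9\n",
"a = np.arange(0, 10, .1)\n",
"b = np.arange(0, 10, .1)\n",
"grid = np.array(np.meshgrid(a, b)).T.reshape(-1, 2)\n",
"\n",
"sse = ((grid[:, 0] + grid[:, 1].reshape(1, -1) * x.reshape(-1, 1) - y.reshape(-1, 1)) ** 2).sum(axis=0)\n",
"best_grid = grid[np.argmin(sse)]\n",
"\n",
"# exact least squares solution for comparison\n",
"X = np.column_stack([np.ones(n), x])\n",
"best_exact, rss, _, _ = np.linalg.lstsq(X, y, rcond=None)\n",
"\n",
"print('grid minimum:', best_grid, 'SSE', sse.min())\n",
"print('least squares:', best_exact, 'SSE', rss[0])\n",
"```"
]
},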
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Connection between least squares and maximum likelihood\n",
"\n",
"$\n",
"\\begin{align}\n",
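"Y &= a + b\\cdot X + e\\\\\n",
"e &\\sim \\mathcal{N}\\left(0, \\sigma^2\\right)\n",
"\\end{align}\n",
"$\n",
"\n",
"where the error $e$ has standard deviation $\\sigma$, and the normal density with mean $\\mu$ and standard deviation $\\sigma$ is\n",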
"\n",
"$\n",
"p(x) = \\frac{1}{\\sigma\\sqrt{2\\pi}}e^{-\\frac{1}{2}\\left(\\frac{x - \\mu}{\\sigma}\\right)^2}\n",
"$"
]
},
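{
"cell_type": "markdown",
"metadata": {},
"source": [
"Taking logarithms makes the connection explicit. Assume the $n$ observations $(x_i, y_i)$ are independent and write $\\sigma$ for the standard deviation of the error $e$. The log-likelihood is then\n",
"\n",
"$$\n",
"\\log L(a, b, \\sigma) = \\sum_{i=1}^{n} \\log p(y_i \\mid x_i) = -n \\log\\left(\\sigma\\sqrt{2\\pi}\\right) - \\frac{1}{2\\sigma^2} \\sum_{i=1}^{n} \\left(y_i - a - b x_i\\right)^2\n",
"$$\n",
"\n",
"For any fixed $\\sigma$, only the last sum depends on $a$ and $b$, so maximizing the likelihood over $a$ and $b$ is the same as minimizing the sum of squared errors. Setting the derivative with respect to $\\sigma$ to zero gives\n",
"\n",
"$$\n",
"\\hat\\sigma^2 = \\frac{1}{n} \\sum_{i=1}^{n} \\left(y_i - \\hat a - \\hat b x_i\\right)^2, \\qquad \\log L(\\hat a, \\hat b, \\hat\\sigma) = -\\frac{n}{2}\\left(\\log\\left(2\\pi\\hat\\sigma^2\\right) + 1\\right)\n",
"$$\n",
"\n",
"Note that $\\hat\\sigma^2$ divides by $n$, whereas the residual variance behind the `std err` column of `statsmodels` divides by $n - 2$."
]
},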
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
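"plt.figure(figsize = (10, 6))\n",
"\n",
"# separate names, so the data stored in x and y are not overwritten\n",
"grid = np.arange(0, 10, .1)\n",
"\n",
"for mu in [3, 5, 8]:\n",
"    for sigma in [.1, .2, .5]:\n",
"        # normal density p(x) with mean mu and standard deviation sigma, as in the formula above\n",
"        density = np.exp(-.5 * ((grid - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))\n",
"        plt.plot(grid, density, label = \"mu = {}, sigma = {}\".format(mu, sigma))\n",
"\n",
"plt.legend();"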
]
},
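{
"cell_type": "markdown",
"metadata": {},
"source": [
"The formula for $p(x)$ can be checked numerically. This sketch treats `scipy.stats.norm` as the reference implementation: it evaluates the formula by hand for one of the plotted settings ($\\mu = 5$, $\\sigma = 0.5$), compares the result with `stats.norm(mu, sigma).pdf`, and verifies with the trapezoidal rule that the density integrates to approximately 1.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"mu, sigma = 5.0, 0.5\n",
"grid = np.linspace(0, 10, 2001)\n",
"\n",
"# p(x) = exp(-((x - mu) / sigma) ** 2 / 2) / (sigma * sqrt(2 pi))\n",
"density = np.exp(-0.5 * ((grid - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))\n",
"\n",
"# compare with scipy's implementation of the same density\n",
"print(np.allclose(density, stats.norm(mu, sigma).pdf(grid)))\n",
"\n",
"# trapezoidal rule; the grid covers mu +/- 10 sigma, so nearly all of the mass is included\n",
"step = grid[1] - grid[0]\n",
"area = step * (density.sum() - 0.5 * (density[0] + density[-1]))\n",
"print(area)\n",
"```"
]
},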
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.43332941, 1.39989887, 1.36834229, 1.33856748, 1.31048854,\n",
" 1.28402542, 1.25910355, 1.23565352, 1.21361074, 1.19291513,\n",
" 1.17351087, 1.15534615, 1.13837294, 1.12254676, 1.10782651,\n",
" 1.09417428, 1.08155519, 1.06993725, 1.05929119, 1.04959041,\n",
" 1.04081077, 1.03293059, 1.02593049, 1.01979334, 1.01450418,\n",
" 1.01005017, 1.00642052, 1.00360649, 1.00160128, 1.00040008,\n",
" 1. , 1.00040008, 1.00160128, 1.00360649, 1.00642052,\n",
" 1.01005017, 1.01450418, 1.01979334, 1.02593049, 1.03293059,\n",
" 1.04081077, 1.04959041, 1.05929119, 1.06993725, 1.08155519,\n",
" 1.09417428, 1.10782651, 1.12254676, 1.13837294, 1.15534615,\n",
" 1.17351087, 1.19291513, 1.21361074, 1.23565352, 1.25910355,\n",
" 1.28402542, 1.31048854, 1.33856748, 1.36834229, 1.39989887,\n",
" 1.43332941, 1.46873282, 1.50621518, 1.54589031, 1.58788033,\n",
" 1.63231622, 1.67933856, 1.7290982 , 1.78175708, 1.83748906,\n",
" 1.89648088, 1.95893312, 2.02506136, 2.09509731, 2.16929016,\n",
" 2.24790799, 2.33123927, 2.41959459, 2.51330847, 2.61274136,\n",
" 2.71828183, 2.83034893, 2.94939481, 3.07590756, 3.21041432,\n",
" 3.35348465, 3.5057343 , 3.66782924, 3.84049014, 4.02449727,\n",
" 4.22069582, 4.43000184, 4.65340868, 4.89199404, 5.14692784,\n",
" 5.41948071, 5.71103347, 6.02308749, 6.3572761 , 6.71537719])"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y"
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# TensorFlow 1.x: enable eager execution before any other TensorFlow operation runs\n",
"tf.enable_eager_execution()"
]
},
{
"cell_type": "code",
"execution_count": 152,
"metadata": {},
"outputs": [],
"source": [
"tx = tf.convert_to_tensor(x)"
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: id=1, shape=(100,), dtype=float64, numpy=\n",
"array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. , 1.1, 1.2,\n",
" 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2. , 2.1, 2.2, 2.3, 2.4, 2.5,\n",
" 2.6, 2.7, 2.8, 2.9, 3. , 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8,\n",
" 3.9, 4. , 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5. , 5.1,\n",
" 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9, 6. , 6.1, 6.2, 6.3, 6.4,\n",
" 6.5, 6.6, 6.7, 6.8, 6.9, 7. , 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7,\n",
" 7.8, 7.9, 8. , 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9, 9. ,\n",
" 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7, 9.8, 9.9])>"
]
},
"execution_count": 153,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tx"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: id=3, shape=(100,), dtype=float64, numpy=\n",
"array([0.000e+00, 1.000e-02, 4.000e-02, 9.000e-02, 1.600e-01, 2.500e-01,\n",
" 3.600e-01, 4.900e-01, 6.400e-01, 8.100e-01, 1.000e+00, 1.210e+00,\n",
" 1.440e+00, 1.690e+00, 1.960e+00, 2.250e+00, 2.560e+00, 2.890e+00,\n",
" 3.240e+00, 3.610e+00, 4.000e+00, 4.410e+00, 4.840e+00, 5.290e+00,\n",
" 5.760e+00, 6.250e+00, 6.760e+00, 7.290e+00, 7.840e+00, 8.410e+00,\n",
" 9.000e+00, 9.610e+00, 1.024e+01, 1.089e+01, 1.156e+01, 1.225e+01,\n",
" 1.296e+01, 1.369e+01, 1.444e+01, 1.521e+01, 1.600e+01, 1.681e+01,\n",
" 1.764e+01, 1.849e+01, 1.936e+01, 2.025e+01, 2.116e+01, 2.209e+01,\n",
" 2.304e+01, 2.401e+01, 2.500e+01, 2.601e+01, 2.704e+01, 2.809e+01,\n",
" 2.916e+01, 3.025e+01, 3.136e+01, 3.249e+01, 3.364e+01, 3.481e+01,\n",
" 3.600e+01, 3.721e+01, 3.844e+01, 3.969e+01, 4.096e+01, 4.225e+01,\n",
" 4.356e+01, 4.489e+01, 4.624e+01, 4.761e+01, 4.900e+01, 5.041e+01,\n",
" 5.184e+01, 5.329e+01, 5.476e+01, 5.625e+01, 5.776e+01, 5.929e+01,\n",
" 6.084e+01, 6.241e+01, 6.400e+01, 6.561e+01, 6.724e+01, 6.889e+01,\n",
" 7.056e+01, 7.225e+01, 7.396e+01, 7.569e+01, 7.744e+01, 7.921e+01,\n",
" 8.100e+01, 8.281e+01, 8.464e+01, 8.649e+01, 8.836e+01, 9.025e+01,\n",
" 9.216e+01, 9.409e+01, 9.604e+01, 9.801e+01])>"
]
},
"execution_count": 156,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.square(tx)"
]
},
{
"cell_type": "code",
"execution_count": 167,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(<tf.Tensor: id=84, shape=(), dtype=float32, numpy=0.70710677>,\n",
" <tf.Tensor: id=79, shape=(), dtype=float32, numpy=3.1415927>)"
]
},
"execution_count": 167,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from math import pi\n",
"\n",
"pi = tf.constant(pi)\n",
"\n",
"tf.sqrt(.5), pi"
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: id=148, shape=(), dtype=float32, numpy=0.25164604>"
]
},
"execution_count": 173,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
" tf.div(1., (2 * pi * tf.sqrt(.4)))"
]
},
{
"cell_type": "code",
"execution_count": 185,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[-1.40992804e-01 -2.35655649e-02 -3.68327589e-05 -6.09704671e-02\n",
" -3.75023362e-01 -2.89459785e+00 -1.86270350e+00 -2.11469516e+00\n",
" -1.04543473e-01 -3.96294989e-01], shape=(10,), dtype=float64)\n",
"tf.Tensor(0.15915494, shape=(), dtype=float32)\n"
]
},
{
"data": {
"text/plain": [
"<tf.Tensor: id=343, shape=(), dtype=float64, numpy=1.0151029856665015>"
]
},
"execution_count": 185,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
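"np.random.seed(10)\n",
"\n",
"n = 10\n",
"\n",
"x = stats.uniform(1, 4).rvs(n)\n",
"y = 3 + 5 * x + stats.norm().rvs(n) * 2\n",
"\n",
"tx = tf.convert_to_tensor(x)\n",
"ty = tf.convert_to_tensor(y)\n",
"\n",
"def likelihood(a, b, sigma):\n",
"    # standardized residuals of the observations around the line a + b * x\n",
"    z = (ty - (a + b * tx)) / sigma\n",
"    # normal density of each observation: exp(-z^2 / 2) / (sigma * sqrt(2 pi));\n",
"    # np.sqrt(2 * np.pi) is a NumPy scalar, not a float32 tensor, so no float/double mismatch arises\n",
"    density = tf.exp(-.5 * tf.square(z)) / (sigma * np.sqrt(2 * np.pi))\n",
"    # independent observations: the likelihood is the product of the densities\n",
"    return tf.reduce_prod(density)\n",
"\n",
"likelihood(3., 5., 1.)"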
]
},
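{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same maximum likelihood estimate can also be found without TensorFlow. The sketch below is an alternative route, not the notebook's own method: it minimizes the negative log-likelihood with `scipy.optimize.minimize` (BFGS), and it parametrizes $\\sigma$ through $\\log\\sigma$ so that $\\sigma$ stays positive, which is an assumption of this sketch. The estimates of $a$ and $b$ should agree with least squares, $\\hat\\sigma$ with $\\sqrt{\\mathrm{RSS}/n}$, and the maximized log-likelihood with $-\\frac{n}{2}(\\log(2\\pi\\hat\\sigma^2) + 1)$, the quantity reported in the `Log-Likelihood` row of the statsmodels summary.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import optimize, stats\n",
"\n",
"# regenerate the data from the top of the notebook (same seed and recipe)\n",
"np.random.seed(10)\n",
"n = 10\n",
"x = stats.uniform(1, 4).rvs(n)\n",
"y = 3 + 5 * x + stats.norm().rvs(n) * 2\n",
"\n",
"def negative_log_likelihood(params):\n",
"    a, b, log_sigma = params\n",
"    sigma = np.exp(log_sigma)  # assumption of this sketch: optimize log(sigma) so that sigma > 0\n",
"    # sum of the log normal densities of the observations around the line a + b x\n",
"    return -np.sum(stats.norm(a + b * x, sigma).logpdf(y))\n",
"\n",
"start = [0.0, 0.0, np.log(np.std(y))]\n",
"result = optimize.minimize(negative_log_likelihood, x0=start, method='BFGS')\n",
"a_ml, b_ml, sigma_ml = result.x[0], result.x[1], np.exp(result.x[2])\n",
"\n",
"# least squares for comparison; the ML estimate of sigma divides RSS by n, not n - 2\n",
"b_ls, a_ls = np.polyfit(x, y, 1)\n",
"sigma_closed_form = np.sqrt(np.mean((y - a_ls - b_ls * x) ** 2))\n",
"\n",
"print('ML:', a_ml, b_ml, sigma_ml, 'log-likelihood:', -result.fun)\n",
"print('LS:', a_ls, b_ls, sigma_closed_form)\n",
"```"
]
},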
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [],
"source": [
"tfe = tf.contrib.eager\n",
"\n",
"grad = tfe.gradients_function(likelihood)"
]
},
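{
"cell_type": "markdown",
"metadata": {},
"source": [
"`gradients_function` returns the partial derivatives of the function with respect to each of its arguments. For the log-likelihood with $\\sigma$ held fixed, these derivatives have a closed form, $\\partial \\log L / \\partial a = \\sum_i r_i / \\sigma^2$ and $\\partial \\log L / \\partial b = \\sum_i r_i x_i / \\sigma^2$, with residuals $r_i = y_i - a - b x_i$. A NumPy sketch, assuming $\\sigma = 1$ and a step size of $1 / (n + \\sum_i x_i^2)$ (both choices of this sketch, not of the notebook), that checks the closed form against finite differences and then runs plain gradient ascent to the least squares solution:\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"# regenerate the data from the top of the notebook (same seed and recipe)\n",
"np.random.seed(10)\n",
"n = 10\n",
"x = stats.uniform(1, 4).rvs(n)\n",
"y = 3 + 5 * x + stats.norm().rvs(n) * 2\n",
"sigma = 1.0  # assumption of this sketch: sigma is held fixed\n",
"\n",
"def log_likelihood(a, b):\n",
"    return np.sum(stats.norm(a + b * x, sigma).logpdf(y))\n",
"\n",
"def gradient(a, b):\n",
"    # analytic partial derivatives of the log-likelihood with respect to a and b\n",
"    r = y - a - b * x\n",
"    return np.array([np.sum(r), np.sum(r * x)]) / sigma ** 2\n",
"\n",
"# central finite differences at (a, b) = (1, 2) as an independent check of the analytic gradient\n",
"h = 1e-6\n",
"numeric = np.array([\n",
"    (log_likelihood(1 + h, 2) - log_likelihood(1 - h, 2)) / (2 * h),\n",
"    (log_likelihood(1, 2 + h) - log_likelihood(1, 2 - h)) / (2 * h),\n",
"])\n",
"print(np.allclose(gradient(1, 2), numeric, rtol=1e-4))\n",
"\n",
"# gradient ascent; 1 / trace(X^T X) never exceeds 1 / (largest curvature), so the iteration is stable\n",
"step = 1.0 / (n + np.sum(x ** 2))\n",
"params = np.zeros(2)\n",
"for _ in range(5000):\n",
"    params = params + step * gradient(*params)\n",
"\n",
"print('gradient ascent:', params)\n",
"print('least squares:', np.polyfit(x, y, 1)[::-1])\n",
"```"
]
},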
{
"cell_type": "code",
"execution_count": 193,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(\n",
"[-29.65723131 -0.9565003 -18.46963644 -22.0529722 -17.13529304\n",
" -13.54313903 -0.21410108 -41.77705614 -3.94846295 -3.37218824], shape=(10,), dtype=float64)\n"
]
},
{
"ename": "InvalidArgumentError",
"evalue": "cannot compute Mul as input #0 was expected to be a double tensor but is a float tensor [Op:Mul] name: mul/",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-193-0f024ea5794c>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mgrad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcast\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m4\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcast\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcast\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1.\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\backprop.py\u001b[0m in \u001b[0;36mdecorated\u001b[1;34m(*args, **kwds)\u001b[0m\n\u001b[0;32m 367\u001b[0m \u001b[1;34m\"\"\"Computes the gradient of the decorated function.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 368\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 369\u001b[1;33m \u001b[0m_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mval_and_grad_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mf\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 370\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mgrad\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 371\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\backprop.py\u001b[0m in \u001b[0;36mdecorated\u001b[1;34m(*args, **kwds)\u001b[0m\n\u001b[0;32m 466\u001b[0m raise ValueError(\"Functions to be differentiated cannot \"\n\u001b[0;32m 467\u001b[0m \"receive keyword arguments.\")\n\u001b[1;32m--> 468\u001b[1;33m \u001b[0mval\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvjp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmake_vjp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mf\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 469\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mval\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvjp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdy\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 470\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\tensorflow\\python\\eager\\backprop.py\u001b[0m in \u001b[0;36mdecorated\u001b[1;34m(*args, **kwds)\u001b[0m\n\u001b[0;32m 522\u001b[0m \u001b[0msources\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 523\u001b[0m \u001b[0mtape\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwatch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 524\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mf\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 525\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 526\u001b[0m raise ValueError(\"Cannot differentiate a function that returns None; \"\n",
"\u001b[1;32m<ipython-input-185-a13d491ea27c>\u001b[0m in \u001b[0;36mlikelihood\u001b[1;34m(a, b, sigma)\u001b[0m\n\u001b[0;32m 15\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtmp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 17\u001b[1;33m \u001b[0mconst\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdiv\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1.\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;36m2.\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mpi\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msigma\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconst\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\math_ops.py\u001b[0m in \u001b[0;36mbinary_op_wrapper\u001b[1;34m(x, y)\u001b[0m\n\u001b[0;32m 848\u001b[0m \u001b[1;32mwith\u001b[0m \u001b[0mops\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname_scope\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mop_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 849\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mops\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mops\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 850\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 851\u001b[0m \u001b[1;32melif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msparse_tensor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSparseTensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 852\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\math_ops.py\u001b[0m in \u001b[0;36m_mul_dispatch\u001b[1;34m(x, y, name)\u001b[0m\n\u001b[0;32m 1092\u001b[0m \u001b[0mis_tensor_y\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mops\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1093\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mis_tensor_y\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1094\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mgen_math_ops\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1095\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1096\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msparse_tensor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSparseTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# Case: Dense * Sparse.\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\gen_math_ops.py\u001b[0m in \u001b[0;36mmul\u001b[1;34m(x, y, name)\u001b[0m\n\u001b[0;32m 4956\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4957\u001b[0m \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 4958\u001b[1;33m \u001b[0m_six\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mraise_from\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_core\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcode\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4959\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4960\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\six.py\u001b[0m in \u001b[0;36mraise_from\u001b[1;34m(value, from_value)\u001b[0m\n",
"\u001b[1;31mInvalidArgumentError\u001b[0m: cannot compute Mul as input #0 was expected to be a double tensor but is a float tensor [Op:Mul] name: mul/"
]
}
],
"source": [
"grad(tf.cast(4, tf.double), tf.cast(3, tf.double), tf.cast(1., tf.double))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}