Last active
September 29, 2019 09:35
-
-
Save urigoren/9a2dbf5996bfc5c39c7e661b117f41ef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
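{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Linear Regression from scratch\n", | |
"## Create synthetic data\n", | |
"\n", | |
"Targets are a noisy linear function of three features drawn from $\\mathcal{U}[0, 1)$: $y = X w^* + \\varepsilon$, with true weights $w^* = (0.1, -0.5, 0.75)$ and noise $\\varepsilon \\sim \\mathcal{U}(-0.001, 0.001)$." | |
] | |
}, | |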
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from tqdm import tqdm_notebook as tqdm\n", | |
"from matplotlib import pyplot as plt\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Fixed seed so that Restart & Run All regenerates exactly the same data\n", | |
"np.random.seed(0)\n", | |
"\n", | |
"n_samples = 2000\n", | |
"X_ = np.random.rand(n_samples, 3)\n", | |
"# Uniform noise in [-0.001, 0.001)\n", | |
"noise = 0.002 * np.random.rand(n_samples, 1) - 0.001\n", | |
"# Targets come from the known weights [0.1, -0.5, 0.75]; y_ has shape (n_samples, 1)\n", | |
"y_ = X_ @ np.array([[0.1, -0.5, 0.75]]).T + noise" | |
] | |
}, | |
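{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Linear regression in PyTorch\n", | |
"\n", | |
"Fit the weights by plain gradient descent on the MSE loss, using autograd to compute the gradients." | |
] | |
}, | |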
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torch import nn\n", | |
"from torch.nn import functional as F\n", | |
"from torch import optim" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Seed torch's RNG so the random initial weights are reproducible\n", | |
"torch.manual_seed(0)\n", | |
"\n", | |
"# torch.tensor copies the float64 numpy arrays, so X and y are torch.double\n", | |
"X = torch.tensor(X_)\n", | |
"y = torch.tensor(y_)\n", | |
"# Row vector of trainable weights, shape (1, 3), double to match X\n", | |
"w = torch.randn(1, 3, requires_grad=True, dtype=torch.double)\n", | |
"losses = []" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "437da563e81a48849429b7668a7ded6b", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 0.1000, -0.5000, 0.7500]], dtype=torch.float64, requires_grad=True)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"lr = 1e-3\n", | |
"n_iters = 100000\n", | |
"log_every = 100  # record the loss once every `log_every` iterations\n", | |
"\n", | |
"for i in tqdm(range(n_iters)):\n", | |
"    y_hat = X @ w.t()  # (n_samples, 3) @ (3, 1) -> (n_samples, 1), same shape as y\n", | |
"    loss = F.mse_loss(y_hat, y)\n", | |
"    if i % log_every == 0:\n", | |
"        # .item() gives a plain Python float without going through .data\n", | |
"        losses.append(loss.item())\n", | |
"    loss.backward()\n", | |
"    # Plain gradient-descent step; no_grad keeps the in-place update out of autograd\n", | |
"    with torch.no_grad():\n", | |
"        w -= w.grad * lr\n", | |
"        w.grad.zero_()  # reset the gradient before the next backward()\n", | |
"w" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Plot loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x1dca573e550>]" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFtBJREFUeJzt3XuMpXd93/H358zsxTe82Du0xvZm7cQQnBQKHRwTgnAaLotb2apEWm+TQChkVSkkaYPSGqW1UyKlIlSBIpnLijhWUWsXEwQra8FNgdRSA8ZjQVxfcFhsgifG2bGNTXzd27d/nGfWs7PntrNnPPscv1/S0ZznOT+f833msT77m9/5Pc8vVYUkabJ01roASdL4Ge6SNIEMd0maQIa7JE0gw12SJpDhLkkTyHCXpAlkuEvSBDLcJWkCTa/VB2/evLm2bt26Vh8vSa10++23P1xVM8ParVm4b926lbm5ubX6eElqpSR/PUo7h2UkaQIZ7pI0gQx3SZpAQ8M9ybVJ9ia5c0CbS5J8K8ldSf7PeEuUJB2rUXru1wHb+r2YZBPwMeCyqvop4BfHU5okaaWGhntV3QI8OqDJvwQ+V1Xfb9rvHVNtkqQVGseY+8uAFyf58yS3J3nHGN5TknQcxhHu08A/Av4J8FbgPyZ5Wa+GSXYkmUsyt7CwsKIPu/ehv+OP/te9PPzEsysuWJIm3TjCfR74UlU9WVUPA7cAr+rVsKp2VtVsVc3OzAy9wKqnPXuf4KNf2cOjT+5becWSNOHGEe5fAN6QZDrJycDPAPeM4X176qT785ALe0tSX0NvP5DkeuASYHOSeeBqYB1AVX2iqu5J8iXgDuAQ8Kmq6jtt8nhlMdwPrdYnSFL7DQ33qto+QpsPAR8aS0VDpEn3wp67JPXTuitUO4vhbrZLUl+tC/dmVMYxd0kaoHXh3mkqPmS2S1JfrQv3w2Pu9twlqa/WhfvimLs9d0nqr3Xhvjjmbs9dkvprXbgfni2zxnVI0omsheHe/XnIcRlJ6qt14c7h2w+sbRmSdCJrXbh3vEJVkoZqb7ib7ZLUV+vCPd4VUpKGal24L36harZLUn+tC/ccvojJdJekftoX7s1Ps12S+mtduHfsuUvSUEPDPcm1SfYmGbi6UpLXJjmY5O3jK+9ozpaRpOFG6blfB2wb1CDJFPBB4OYx1DSQs2Ukabih4V5VtwCPDmn2G8CfAnvHUdQg8QpVSRrquMfck5wN/DPgE8dfznCLwzLeOkyS+hvHF6ofAf59VR0c1jDJjiRzSeYWFhZW9GHez12Shpsew3vMAjc08883A5cmOVBVn1/esKp2AjsBZmdnVxTPjrlL0nDHHe5Vdd7i8yTXATf1CvZx8QpVSRpuaLgnuR64BNicZB64GlgHUFXPyzj7snoAe+6SNMjQcK+q7aO+WVX96nFVMwKvUJWk4Vp7har3c5ek/lob7ocOrXEhknQCa124O1tGkoZrbbgb7ZLUX+vC/bkbhxnvktRPa8PdK1Qlqb/Whbtj7pI0XGvD3WyXpP5aF+6OuUvScK0L98UrVB1zl6T+Whfu9twlabjWhrs9d0nqr3XhjrNlJGmo1oX74v3cJUn9tTDcvZ+7JA3TunB/7iKmta1Dkk5kQ8M9ybVJ9ia5s8/rv5TkjubxF0leNf4yn/PcbJnV/BRJardReu7XAdsGvH4/8MaqeiXw+zQLYK8Wbz8gScONsszeLUm2Dnj9L5Zsfh045/jL6s957pI03LjH3N8NfHHM73kEr1CVpOGG9txHleTn6Yb7zw1oswPYAbBly5YVfY5j7pI03Fh67kleCXwKuLyqHunXrqp2VtVsVc3OzMys8LO6Px1zl6T+jjvck2wBPgf8SlX91fGXNPTzAMfcJWmQocMySa4HLgE2J5kHrgbWAVTVJ4CrgDOBjzXBe6CqZlerYOhepWq0S1J/o8yW2T7k9fcA7xlbRSPoJA
7LSNIArbtCFbrj7s6WkaT+WhrucbaMJA3QynDvxC9UJWmQVoZ7cMxdkgZpZbh3e+5rXYUknbhaGu7xC1VJGqCV4d6dLWO6S1I/LQ33+IWqJA3QynD3ClVJGqyl4e5sGUkapJXh7hWqkjRYS8PdK1QlaZBWhrtXqErSYK0Md69QlaTBWhnuXqEqSYO1MtzjFaqSNNDQcE9ybZK9Se7s83qSfDTJniR3JHnN+Mtc/pmOuUvSIKP03K8Dtg14/W3ABc1jB/Dx4y9rsE7iRUySNMDQcK+qW4BHBzS5HPhv1fV1YFOSs8ZVYC8d7y0jSQONY8z9bOCBJdvzzb6jJNmRZC7J3MLCwoo/0LtCStJg4wj39NjXM3qramdVzVbV7MzMzHF9oj13SepvHOE+D5y7ZPsc4MExvG9fnXjnMEkaZBzhvgt4RzNr5mLg8ar6wRjety/H3CVpsOlhDZJcD1wCbE4yD1wNrAOoqk8Au4FLgT3AU8C7VqvYwzV5haokDTQ03Ktq+5DXC/j1sVU0gniFqiQN1MorVJ0tI0mDtTLcvUJVkgZrZbh7haokDdbScHe2jCQN0spwxzF3SRqoleHuSkySNFhLw901VCVpkJaGu2PukjRIK8PdK1QlabB2hrtXqErSQK0Md8fcJWmwVoZ7HHOXpIFaGe5eoSpJg7Uy3BM46FVMktRXK8N9quNsGUkapJ3hnthzl6QBRgr3JNuS3JtkT5Ire7y+JclXk3wzyR1JLh1/qc/pdAx3SRpkaLgnmQKuAd4GXAhsT3Lhsmb/AfhMVb0auAL42LgLXWracJekgUbpuV8E7Kmq+6pqH3ADcPmyNgW8qHl+OvDg+Eo8WqcTDjrmLkl9jRLuZwMPLNmeb/Yt9XvALzcLaO8GfqPXGyXZkWQuydzCwsIKyu2aSjhkz12S+hol3NNj3/Jk3Q5cV1XnAJcCn05y1HtX1c6qmq2q2ZmZmWOvtjFtz12SBhol3OeBc5dsn8PRwy7vBj4DUFVfAzYCm8dRYC+dTjh40HCXpH5GCffbgAuSnJdkPd0vTHcta/N94BcAkryCbrivfNxliKnYc5ekQYaGe1UdAN4L3AzcQ3dWzF1JPpDksqbZ+4BfS/KXwPXAr9YqLpXUnQq5Wu8uSe03PUqjqtpN94vSpfuuWvL8buD14y2tv2mvUJWkgdp5hWonHLDrLkl9tTLcOwnOhJSk/loZ7lMd7wopSYO0NNw7zpaRpAFaGu723CVpkHaGu7f8laSBWhnunU73jgjeX0aSemtluE834e64uyT11spwX+y5OzQjSb21MtynYrhL0iDtDHeHZSRpoFaHu1+oSlJvrQ73A4a7JPXUynDvxJ67JA3SynB3zF2SBmt3uNtzl6SeRgr3JNuS3JtkT5Ir+7T550nuTnJXkv8x3jKP5FRISRps6EpMSaaAa4A3010s+7Yku5rVlxbbXAC8H3h9Vf0wyUtWq2Cw5y5Jw4zSc78I2FNV91XVPuAG4PJlbX4NuKaqfghQVXvHW+aRDt9bxjF3SepplHA/G3hgyfZ8s2+plwEvS/J/k3w9ybZeb5RkR5K5JHMLCwsrq5gl95ZxpT1J6mmUcE+Pfcu7zNPABcAlwHbgU0k2HfUfVe2sqtmqmp2ZmTnWWg9bnAp54JDpLkm9jBLu88C5S7bPAR7s0eYLVbW/qu4H7qUb9qviuStUV+sTJKndRgn324ALkpyXZD1wBbBrWZvPAz8PkGQz3WGa+8ZZ6FJTTdXOc5ek3oaGe1UdAN4L3AzcA3ymqu5K8oEklzXNbgYeSXI38FXgd6rqkdUqeqrTLdvZMpLU29CpkABVtRvYvWzfVUueF/DbzWPVOc9dkgZr5RWqncVhGcNdknpqZbgv9tyd5y5JvbUy3KenHJaRpEFaGe4dx9wlaaBWhrv3lpGkwVoZ7od77o65S1JPrQz3xTF3V2KSpN5aGe5TcQ1VSRqkleHuLX8labBWhrtXqErSYO0M947DMpI0SCvD3Y
uYJGmwVob7uuaev/tdikmSemp1uO87YLhLUi+tDPcN04s9d4dlJKmXkcI9ybYk9ybZk+TKAe3enqSSzI6vxKM5LCNJgw0N9yRTwDXA24ALge1JLuzR7jTgN4Fbx13kclOd0InhLkn9jNJzvwjYU1X3VdU+4Abg8h7tfh/4Q+CZMdbX17qpjmPuktTHKOF+NvDAku35Zt9hSV4NnFtVN42xtoHWT3fYZ89dknoaJdzTY9/hbzKTdIAPA+8b+kbJjiRzSeYWFhZGr7KH9VMdh2UkqY9Rwn0eOHfJ9jnAg0u2TwN+GvjzJN8DLgZ29fpStap2VtVsVc3OzMysvGq6wzL7DzhbRpJ6GSXcbwMuSHJekvXAFcCuxRer6vGq2lxVW6tqK/B14LKqmluVihvrpuOwjCT1MTTcq+oA8F7gZuAe4DNVdVeSDyS5bLUL7GfdlGPuktTP9CiNqmo3sHvZvqv6tL3k+Msabv1Uh/3OlpGknlp5hSp0Z8v4haok9dbacHdYRpL6a3G4x9kyktRHa8N9/fSUPXdJ6qO94T4Vx9wlqY/Whrv3lpGk/lod7vbcJam3loe7X6hKUi+tDXfvCilJ/bU33P1CVZL6am24+4WqJPXX3nD39gOS1Fdrw33DdPcL1YOH/FJVkpZrbbifsr57Q8un9x9c40ok6cTT2nA/af0UAE89e2CNK5GkE09rw/2UDd1wf3KfPXdJWm6kcE+yLcm9SfYkubLH67+d5O4kdyT5cpIfG3+pRzq5GZZ50p67JB1laLgnmQKuAd4GXAhsT3LhsmbfBGar6pXAZ4E/HHehyy2OuT9lz12SjjJKz/0iYE9V3VdV+4AbgMuXNqiqr1bVU83m14Fzxlvm0U4+PCxjz12Slhsl3M8GHliyPd/s6+fdwBePp6hRHO65P2vPXZKWG2WB7PTY13NyeZJfBmaBN/Z5fQewA2DLli0jltjbyYuzZey5S9JRRum5zwPnLtk+B3hweaMkbwJ+F7isqp7t9UZVtbOqZqtqdmZmZiX1HvZcuNtzl6TlRgn324ALkpyXZD1wBbBraYMkrwY+STfY946/zKOdsqGZLWPPXZKOMjTcq+oA8F7gZuAe4DNVdVeSDyS5rGn2IeBU4MYk30qyq8/bjc2G6Q6dOOYuSb2MMuZOVe0Gdi/bd9WS528ac11DJeGU9dP23CWph9ZeoQrdoZknnjHcJWm5Vof7Gaes59En9611GZJ0wml1uJ956noefqLnxBxJekFrdbhvPnUDDz9hz12Slmt5uK/nkSefpcoFOyRpqVaH+5mnbuCZ/Ye87a8kLdPucD9lPQCPOO4uSUdodbi/5EUbAXjo8WfWuBJJOrG0OtzP33wKAN9deHKNK5GkE0urw/3sTSexcV2HPXufWOtSJOmE0upw73TCj8+cyp4Fw12Slmp1uANceNaLuGP+MQ4ecjqkJC1qfbi/8eUzPPbUfr71wGNrXYoknTBaH+5v+IkZ1k93uHHugeGNJekFovXhfvrJ67jitefy2dvn+cb9j651OZJ0Qmh9uAO87y0v59wzTuZX/vhWdt7yXZ72ilVJL3AjhXuSbUnuTbInyZU9Xt+Q5H82r9+aZOu4Cx3k9JPWceO/fh0/++Nn8ge7v81Ff/C/+Z0b/5I/u/tvefzp/c9nKZJ0Qhi6ElOSKeAa4M10F8u+Lcmuqrp7SbN3Az+sqp9IcgXwQeBfrEbB/Ww+dQN/8q6LuO17j3LDNx7gi3c+xI23z9MJ/NRLT+cfnHM6r/j7p/GTZ72I8zefwhmnrCfJ81miJD1vMuyOikleB/xeVb212X4/QFX95yVtbm7afC3JNPAQMFMD3nx2drbm5ubGcAi9PXvgIN/8/mN87buPcOv9j3D3gz/iR0tWbdq4rsNLN53E2ZtOYua0DWw6aT2bTl7XPNZz2sZpNk5PsXFdh43rpjhp3RQb13W3N0xP0enAVMJUJ/4jIel5k+T2qpod1m6UNVTPBpZORZkHfqZfm6
o6kORx4Ezg4dHKHb8N01NcfP6ZXHz+mTR18YPHn+HbD/2I7z38FA8+9jR/0zzuW3iSx57at+K7S3YC050OnU7zMzA91aGT0AkkELr/ACz+OxA44h+Fw/sPv54j2tK0P/xf9HkfSSe+K157Lu95w/mr+hmjhHuv5FjeIx+lDUl2ADsAtmzZMsJHj08SXrrpJF666aS+bfYdOMTjT+/nsaf28XfPHuCZ/Qd5dv8hnt5/kGf2H+SZ/Ye6+w4c4lAVBw8VBw4VhxZ/NvsWHwcOFVXF4t8v1fxKqp775XSfH25w+MfiHz1Htlt8vuQ1r92SWmfzqRtW/TNGCfd54Nwl2+cAD/ZpM98My5wOHDUvsap2AjuhOyyzkoJX0/rpDjOnbWDmtNX/xUvSahpltsxtwAVJzkuyHrgC2LWszS7gnc3ztwNfGTTeLklaXUN77s0Y+nuBm4Ep4NqquivJB4C5qtoF/DHw6SR76PbYr1jNoiVJg40yLENV7QZ2L9t31ZLnzwC/ON7SJEkrNRFXqEqSjmS4S9IEMtwlaQIZ7pI0gQx3SZpAQ+8ts2ofnCwAf73C/3wza3hrgzXiMb8weMwvDMdzzD9WVTPDGq1ZuB+PJHOj3DhnknjMLwwe8wvD83HMDstI0gQy3CVpArU13HeudQFrwGN+YfCYXxhW/ZhbOeYuSRqsrT13SdIArQv3YYt1t1WSc5N8Nck9Se5K8lvN/jOS/FmS7zQ/X9zsT5KPNr+HO5K8Zm2PYGWSTCX5ZpKbmu3zmkXWv9Msur6+2b+mi7CPU5JNST6b5NvN+X7dJJ/nJP+2+X/6ziTXJ9k4iec5ybVJ9ia5c8m+Yz6vSd7ZtP9Oknf2+qxRtCrclyzW/TbgQmB7kgvXtqqxOQC8r6peAVwM/HpzbFcCX66qC4AvN9vQ/R1c0Dx2AB9//ksei98C7lmy/UHgw83x/pDu4uuwZBF24MNNu7b6r8CXquongVfRPf6JPM9JzgZ+E5itqp+me9vwK5jM83wdsG3ZvmM6r0nOAK6mu5TpRcDVi/8gHLPuMnDteACvA25esv1+4P1rXdcqHesXgDcD9wJnNfvOAu5tnn8S2L6k/eF2bXnQXdXry8A/Bm6iu1zjw8D08vNNdz2B1zXPp5t2WetjWMExvwi4f3ntk3qeeW595TOa83YT8NZJPc/AVuDOlZ5XYDvwySX7j2h3LI9W9dzpvVj32WtUy6pp/hR9NXAr8Peq6gcAzc+XNM0m4XfxEeDfAYea7TOBx6rqQLO99JiOWIQdWFyEvW3OBxaAP2mGoz6V5BQm9DxX1d8A/wX4PvADuuftdib/PC861vM6tvPdtnAfaSHuNktyKvCnwL+pqh8NatpjX2t+F0n+KbC3qm5furtH0xrhtTaZBl4DfLyqXg08yXN/qvfS6uNuhhQuB84DXgqcQndIYrlJO8/D9DvOsR1/28J9lMW6WyvJOrrB/t+r6nPN7r9Nclbz+lnA3mZ/238XrwcuS/I94Aa6QzMfATY1i6zDkcd0+HgHLcLeAvPAfFXd2mx/lm7YT+p5fhNwf1UtVNV+4HPAzzL553nRsZ7XsZ3vtoX7KIt1t1KS0F2L9p6q+qMlLy1dfPyddMfiF/e/o/nW/WLg8cU//9qgqt5fVedU1Va65/ErVfVLwFfpLrIORx9v6xdhr6qHgAeSvLzZ9QvA3UzoeaY7HHNxkpOb/8cXj3eiz/MSx3pebwbekuTFzV89b2n2Hbu1/gJiBV9YXAr8FfBd4HfXup4xHtfP0f3z6w7gW83jUrrjjV8GvtP8PKNpH7ozh74L/D+6sxHW/DhWeOyXADc1z88HvgHsAW4ENjT7Nzbbe5rXz1/ruo/jeP8hMNec688DL57k8wz8J+DbwJ3Ap4ENk3iegevpfq+wn24P/N0rOa/Av2qOfw/wrpXW4xWqkjSB2jYsI0kageEuSRPIcJekCWS4S9IEMtwlaQIZ7pI0gQx3SZpAhrskTaD/D1bhiKAKH7
JFAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# The training loop records one loss value every 100 iterations\n", | |
"iterations = [100 * k for k in range(len(losses))]\n", | |
"fig, ax = plt.subplots()\n", | |
"# Log scale keeps the late, near-zero part of the curve visible\n", | |
"ax.semilogy(iterations, losses)\n", | |
"ax.set(title='Training loss', xlabel='iteration', ylabel='MSE loss (log scale)');" | |
] | |
}, | |
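{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Sanity check\n", | |
"\n", | |
"Compare against the true generating weights $w^* = (0.1, -0.5, 0.75)$ that were used to build the data." | |
] | |
}, | |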
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Shape (1, 3) like w, so X @ best_w.t() is (n_samples, 1) and matches y.\n", | |
"# A 1-D weight vector would make X @ best_w shape (n_samples,), which F.mse_loss\n", | |
"# broadcasts against y (n_samples, 1) into an (n_samples, n_samples) grid of errors.\n", | |
"best_w = torch.tensor([[0.1, -0.5, 0.75]], dtype=torch.double)\n", | |
"loss = F.mse_loss(X @ best_w.t(), y)\n", | |
"# Expect roughly the variance of the uniform noise, 0.002**2 / 12 ~ 3.3e-7\n", | |
"print(loss.item())\n", | |
"# The learned weights should have converged to the generating weights\n", | |
"assert torch.allclose(w.detach(), best_w, atol=1e-3), w" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment