Last active
March 20, 2018 19:17
-
-
Save fehiepsi/e417e549bddffbd223839766a1267eb5 to your computer and use it in GitHub Desktop.
GP timeseries
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"import matplotlib.pyplot as plt\n", | |
"import torch\n", | |
"import pyro\n", | |
"import pyro.contrib.gp as gp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl4VPW9x/H3Nys7ARKQPawiKpthcQVxqbVcrEsVtyKuUK1r63qrdrFXbasVq21VFFSqIEVFsbaIcq2tAmELS5SwL4khLIEECFnmd/+YwZtiQoZkZs7M5PN6nnlm5pwzz3w4mXw4+c1ZzDmHiIjEvgSvA4iISGio0EVE4oQKXUQkTqjQRUTihApdRCROqNBFROKECl1EJE6o0EVE4oQKXUQkTiRF8s3S09NdZmZmJN9SRCTmLVmyZKdzLqOu5SJa6JmZmWRnZ0fyLUVEYp6ZbQ5mOQ25iIjECRW6iEicUKGLiMQJFbqISJxQoYuIxAkVuohInFChi4jECRW6iEgY5Rcf5BfvrWHvwYqwv1dEDywSEWkstu05wPML1vNW9lYATuvVjnP7dwjre6rQRURCaOvuAzy/YB2zlmzDMK4Y2pVJo3rTOa1p2N9bhS4iEgKbd+3nuU/WMXvpdhLMuHJYNyaO7EWnCBT5YSp0EZEG2LhzP3/4eB3vLN9OYoJxzYjuTBzZi+NaN4l4FhW6iEg9rC8q5blAkScnJjD+1EwmjuxJ+1aRL/LDVOgiIsdg3Y4Snv14He+tyCclKYEbzujBTWf1pH1L74r8MBW6iEgQ1haWMHl+HnNXFtAkKZGbzuzJTWf1JL1FqtfRvqFCFxE5ir0HK3jw7ZXMzSmgeUoiE0f24sYzetAuior8sKAL3cwSgWxgu3NujJlNBUYCewOLXOecWx76iCIi3vD5HPfMXMGCr3Zw69m9uPGMnrRpnuJ1rFodyxb6HUAu0KratJ8652aFNpKISHT486cb+Ci3kIfH9Of6M3p4HadOQR36b2ZdgO8BL4U3johIdPh8/S5+8/cv+d6Ajkw4PdPrOEEJ9lwuvwfuBXxHTH/MzHLM7Gkzi74BJRGReijcV8aP31hGZnpznrh0AGbmdaSg1FnoZjYG2OGcW3LErAeAfsBQoC1wXy2vv9nMss0su6ioqKF5RUTCqqLKx21/Wcr+Q5X86ZpTaJEaO/uOBLOFfjow1sw2AW8Co83sdedcgfM7BLwCDKvpxc65F5xzWc65rIyMjJAFFxEJhyc//JLFm/bw+KUn07dDS6/jHJM6C90594BzrotzLhMYB3zsnLvGzDoCmP9vke8Dq8KaVEQkzD5cVcCL/9zItSO6c9Ggzl7HOWYN+VtiupllAAYsByaGJpKISORtKCrlJ2/lMLBrGv895gSv49TLMRW6c24BsCDweHQY8oiIRNzB8ip+NH0pyYnG81cPITUp0etI9RI7o/0iImHgnOOhd1byVWEJUycMi8h5y8NFl6ATkUbtjUVbmb10O7eP7sPIvrG944YKXUQarZxtxTw6ZzVn9knn9nP6eB2nwVToItIoFR8oZ9LrS0lvkcIz4waTmBAbBw8djcbQRaTR8fkcd81Yzo6SMmbecipto/iEW8dCW+gi0ug8v2Adn3xVxM/G9GdwtzZexwkZFbqINCqf5e3kqXlrGTuwE9eO6O51nJBSoYtIo1Gw9yC3v7mMnhkt+J9LTo6Zk24FS4UuIo1CeaWPW6cv5VBFFX+65hSax9BJt4IVf/8iEZEa/M/fclm6pZg/XDWY3u1beB0nLLSFLiJx7/2cfF751yauOy2TMQM6eR0nbFToIhLXvvq6hPtm5TCkWxoPXhibJ90KlgpdROJW4b4yJryyiGapSTx39RBSkuK78uL7XycijVbpoUque2Uxew9W8Mp1Q+nYOnZPuhUsfSkqInGnosrHpNeXsLawhCnjszipc2uvI0WEttBFJK4453hw9k
r+mbeTX198EqOOb+91pIhRoYtIXJk8fx1vLdnG7aN7c8XQbl7HiaigC93MEs1smZm9H3jew8wWmlmemc0ws/g4u42IxKy3srfy9EdruWRIZ+46r6/XcSLuWLbQ7wByqz1/AnjaOdcH2APcEMpgIiLH4p95RTwweyVn9E7n8UsGxN1h/cEIqtDNrAvwPeClwHMDRgOzAotMA74fjoAiInVZk7+PSa8vpXf7Fjx/TfzvnlibYP/VvwfuBXyB5+2AYudcZeD5NqBzTS80s5vNLNvMsouKihoUVkTkSPnFB5kwdREtUpN4ZcJQWjVJ9jqSZ+osdDMbA+xwzi2pPrmGRV1Nr3fOveCcy3LOZWVkxPb1+kQkuuwrq2DCK4s5cKiKVyY0jn3NjyaY/dBPB8aa2YVAE6AV/i32NDNLCmyldwHywxdTROQ/lVf69zVfX1TK1AnDOKFjK68jea7OLXTn3APOuS7OuUxgHPCxc+5q4BPgssBi44F3w5ZSRKQa5xz3/zWHf63bxeOXDuCMPuleR4oKDfnm4D7gbjNbh39MfUpoIomIHN1T89Yye9l27j6vL5ed0sXrOFHjmA79d84tABYEHm8AhoU+kohI7d5YtIVnP17HFVld+fHo3l7HiSqNc98eEYlJn3y1g/9+ZxVn9c3gVxef1Cj3NT8aFbqIxIRV2/dy6/Sl9DuuJc9fPYTkRNXXkbRGRCTqbd19gAlTF9OmWQovXzeUFnF4PdBQ0FoRkai1e385izbu4rf/WEtZRRXTbxxOh1ZNvI4VtVToIhI1ikoOsWjjbr7YsIuFG3extrAUgJZNknjh2iz6dmjpccLopkIXkaNyzrG+qJSWTZJp1zyFpBCOXRfuKwuU924WbtjF+qL9ADRLSSQrsy0XDerMiJ5tOblzWqM9P8uxUKGLSK1Wbd/LI3NWs2TzHgASDNJbpNKhVRM6tEqlfasmdGjpf9yhVRMyWvrv2zVPISHh23ug5BcfZOHGXSzcsJuFG3ezcae/wFukJjE0sw0/yOrK8B5tOalza33pWQ8qdBH5luID5fz2H1/xl4VbaNMshYfH9CclKYEd+8oo3HeIwpIytheXsWxLMbv2l3/r9UkJRkbLw4WfStOURJZtKWbL7gMAtGqSxLAebblqWDeG92xL/46tQrrl31ip0EXkG1U+x4zFW/nN379kX1kl40/L5M5z+9K6ae1nMCyv9FFUeojCfWX/X/iB+x0lZWzatZ+SskpO7tya8adlMrxHW07o2IrEGrbgpWFU6CICwLIte3hkzmpytu1lWI+2/HzsiUGd8ColKYHOaU3pnNa4z3QYDVToIo3cztJDPPnhl8zM3kaHVqk8M24QYwd20lGYMUiFLtJIVVb5eO2LzTw1z7+P98SRvfjx6N4010E7MUs/OZFG6IsNu3h0zmq+/LqEM/uk8+jYE+mV0cLrWNJAKnSRRuTrvWX8+oNc5qzIp3NaU/587Smc37+DhlfihApdpBEor/Tx8r82Mnl+HpU+xx3n9GHiyF40TUn0OpqEkApdJM5t3LmfG6YtZkPRfs49oQMPj+lPt3bNvI4lYVBnoZtZE+BTIDWw/Czn3CNmNhUYCewNLHqdc255uIKKyLErq6hi0utL2LO/nFeuG8rZ/dp7HUnCKJgt9EPAaOdcqZklA5+Z2d8C837qnJsVvngi0hA/f8//xefUCUMZdbzKPN7VWejOOQeUBp4mB24unKFEpOHeXb6dNxZtZdKoXirzRiKokyeYWaKZLQd2APOccwsDsx4zsxwze9rMUmt57c1mlm1m2UVFRSGKLSJHs76olAdnrySrexvuOa+v13EkQoIqdOdclXNuENAFGGZmJwEPAP2AoUBb4L5aXvuCcy7LOZeVkZERotgiUpuyiipunb6UlKQEnr1qsE561Ygc00/aOVcMLAAucM4VOL9DwCvAsDDkE5Fj9Iv31/Dl1yU8dcUgOrbW+VUakzoL3cwyzCwt8LgpcC7wpZl1DEwz4P
vAqnAGFZG6zVmRz18WbmHiyF6crXHzRieYvVw6AtPMLBH/fwAznXPvm9nHZpYBGLAcmBjGnCJSh4079/PAX3M4pXsb7jlf4+aNUTB7ueQAg2uYPjosiUTkmJVVVPGj6UtJTkrg2SsH62o/jZSOFBWJA7+au4bcgn28fF0WnXRe8kZL/42LxLj3VuTz+hdbuOWsnozu18HrOOIhFbpIDNu0cz8PzF7JkG5p/OQ7x3sdRzymQheJUWUVVdz6l6UkJhjPXjVE4+aiMXSRWPXY3FxW5+/jpR9m6XqeAmgLXSQmzc0p4LUvNnPTmT04t7/GzcVPhS4SYzbv2s99f81hcLc07r2gn9dxJIqo0EViyKHKauPm2t9cjqAxdJEY8uu5uazavo8Xf5hFlza66pD8J/33LhIjPlhZwLTPN3PDGT04T+PmUgMVukgM2LLrAPfNymFg1zTu07i51EKFLhLlDo+bm8EfrhxMSpJ+baVmGkMXiWJVPsfdM1ewcvte/nztKXRtq3FzqZ3+qxeJUs45Hnp7JXNzCnjgu/34zonHeR1JopwKXSQKOed4bG4uby7eym1n9+aWkb28jiQxQIUuEoUmz1/HS59t5LrTMnWxCglaMJega2Jmi8xshZmtNrOfB6b3MLOFZpZnZjPMLCX8cUXi38ufbeTpj9Zy6ZAuPDymP/6rPIrULZgt9EPAaOfcQGAQcIGZjQCeAJ52zvUB9gA3hC+mSOMwc/FWfvH+Gi448TieuPRkEhJU5hK8Ogvd+ZUGniYHbg4YDcwKTJ+G/0LRIlJPH6ws4P7ZOZzZJ51nrhxEkg7rl2MU1CfGzBLNbDmwA5gHrAeKnXOVgUW2AZ1ree3NZpZtZtlFRUWhyCwSdxZ8tYM73lzGkG5t+PO1p5CalOh1JIlBQRW6c67KOTcI6AIMA06oabFaXvuCcy7LOZeVkZFR/6QicWrRxt1MfH0Jfdq3ZMp1Q2mWosNDpH6O6W8651wxsAAYAaSZ2eFPXhcgP7TRROLfqu17uWHqYjqlNeXVG4bRummy15EkhgWzl0uGmaUFHjcFzgVygU+AywKLjQfeDVdIkXi0bkcJP3x5Ea2aJjP9xuGkt0j1OpLEuGD+tusITDOzRPz/Acx0zr1vZmuAN83sV8AyYEoYc4rEla27D3D1SwtJTDCm3zicjq11CTlpuDoL3TmXAwyuYfoG/OPpInIMCveVcfVLCymr8DHjlhFkpjf3OpLECe0XJRJBe/aXc+2UhewqPcS064fR77hWXkeSOKKv00UipKSsgvGvLGLTrgNMmzCMQV3TvI4kcUZb6CIRUFZRxY3TslmTv48/Xj2EU3u18zqSxCFtoYuEkc/n2F58kEfmrGbRpt08M24w55ygy8dJeKjQRUKgyufYuvsAawtLyNtRyrodpeTtKGHdjlLKKnwA/Prikxk7sJPHSSWeqdAl7s1Zkc/MxVtplpJIWrNk0pql0LppMmnNkmnTLIW0psm0DkxPa5pMs5TEWs9wWFnlY/PuA+QVlpBXWEreDv9tfVEp5ZW+b5br1LoJvTu05Orh7ejTvgUDuqTRv5O+AJXwUqFL3Co+UM7P3l3NeyvyyWzXjCbJieRsq6D4YPk3W801SUlM8Bd8oPRbN00hOdHYULSfDTtLqaj6/7NcdGnTlD7tW3Bmn3T6tG9Bnw4t6ZXRnJZNdMSnRJ4KXeLSP/OK+MlbK9hVWs5Pzu/LxJG9/uPshWUVVew9WMGeA+UUH6ig+EAFew8GHh+soLja9O3FBzlUWUXP9OaMPqG9v7jbt6RX++Y674pEFX0aJa4cLK/i8b/lMu3zzfRu34Ip44dyUufW31quSXIiTZIT6dCqiQcpRcJDhS5xY8XWYu6auZwNRfu5/vQe3HvB8TRJ1mlopfFQoUvMq6zy8dwn65n8cR7tW6Yy/cbhnN473etYIhGnQpeYtqGolLtmrmDF1mIuHtyZR8eeqFPQSqOlQpeY5Jzj9S8289gHuaQmJfKHqw
YzZoD28ZbGTYUuMadwXxk/nZXDp2uLOKtvBr+5bIC+3BRBhS4xZm5OAQ+9s5Kyiip+edGJXDOie60HAYk0Nip0iQl7D1bwyLureGd5PgO7pvH05QPpmdHC61giUaXOQjezrsCrwHGAD3jBOfeMmT0K3AQUBRZ90Dn3QbiCSuO1On8vN07LZkfJIe46ty+3nv2fBwmJiF8wW+iVwD3OuaVm1hJYYmbzAvOeds79NnzxpLErKatg0utLcQ5mTzqNgTqHuEitgrkEXQFQEHhcYma5QOdwBxNxzvHQ26vYXnyQGTePUJmL1OGY/m41s0z81xddGJh0m5nlmNnLZtYmxNmkkXtryTbmrMjnrnP7kJXZ1us4IlEv6EI3sxbAX4E7nXP7gD8CvYBB+Lfgf1fL6242s2wzyy4qKqppEZFvWbejhEfeXc2pPdsxaVRvr+OIxISgCt3MkvGX+XTn3GwA51yhc67KOecDXgSG1fRa59wLzrks51xWRkZGqHJLHCurqOK2vyyjaUoivx83iMQE7ZYoEow6C938O/lOAXKdc09Vm96x2mIXA6tCH08ao8fm5vLl1yX87gcDdcCQyDEIZi+X04FrgZVmtjww7UHgSjMbBDhgE3BLWBJKo/LhqgJe+2IzN53Zg7P7tfc6jkhMCWYvl8+Amv7m1T7nElLb9hzg3lk5DOjSmp9+p5/XcURijo7OkKhQWeXjjjeX43Pw7JWDSUnSR1PkWOnQf4kKv/8ojyWb9zD5ysF0b9fc6zgiMUmbQeK5f63byXML1nFFVlfGDtQpcEXqS4UuntpZeog7ZyynZ3pzHhnb3+s4IjFNQy7iGZ/Pcc/MFew9WMGr1w+jWYo+jiINoS108cxLn23gf9cW8bMx/TmhYyuv44jEPBW6eGL51mKe/PArLjjxOK4Z3s3rOCJxQYUuEbevrILb31hGh1ZNeOLSAbrikEiIaNBSIso5x4OzV7K9+CAzbxlB62bJXkcSiRvaQpeImpm9lfdzCrj7vL6c0l2nxBUJJRW6RExeYQmPzFnNGb3TmTSyl9dxROKOCl0i4vApcZunJPHU5QNJ0ClxRUJOY+gSEb98fw1fFZYwdcJQ2uuUuCJhoS10CbtZS7YxfeEWbjmrJ6OO1ylxRcJFhS5hNTN7Kz+dtYLTe7fjnvOP9zqOSFxToUvY/GXhFu6dlcMZvdOZMn6oTokrEmbBXIKuq5l9Yma5ZrbazO4ITG9rZvPMLC9w3yb8cSVWvPr5Jh58eyVnH5/Biz/MoklyoteRROJeMJtMlcA9zrkTgBHArWbWH7gfmO+c6wPMDzwX4aV/buDhd1dzXv8O/OnaU1TmIhFSZ6E75wqcc0sDj0uAXKAzcBEwLbDYNOD74QopseNP/7ueX83N5cKTj+P5q4eQmqQyF4mUY9pt0cwygcHAQqCDc64A/KVvZtp9oZF7dn4ev5u3lrEDO/HU5QNJStSYuUgkBf0bZ2YtgL8Cdzrn9h3D6242s2wzyy4qKqpPRolyzjme+sdX/G7eWi4Z3JmnrxikMhfxQFC/dWaWjL/MpzvnZgcmF5pZx8D8jsCOml7rnHvBOZflnMvKyMgIRWaJIs45nvz7V0z+eB2XZ3XhNz8YSKKOAhXxRDB7uRgwBch1zj1VbdYcYHzg8Xjg3dDHk2jmnOPXH+TyxwXruWp4Nx6/ZIDKXMRDwYyhnw5cC6w0s+WBaQ8CjwMzzewGYAvwg/BElGjknOPn761h6r83cd1pmTzyX/11XnMRj9VZ6M65z4DaflPPCW0ciQU+n+Nn765i+sIt3HhGDx763gkqc5EooJNzyTGp8jkemJ3DzOxtTBrVi3u/c7zKXCRKqNAlaFU+x0/fWsHsZdu5/Zw+3HVuH5W5SBRRoUtQKqt83DVzBe+tyOfu8/py+zl9vI4kIkdQoctRVVT52FVazi/eX80HK7/mvgv6MWmUrjYkEo1U6DGgyufYuHM/SQlGSlICqUkJpBy+JSYc87BHeaWPXf
sPsbOknJ2lhygqPcTOUv/zotJD7CwJPC89xJ4DFd+87r+/dwI3ntkz1P88EQkRFXqUK9xXxqTXl7B0S3Gty6QkJZCa+P8lf2ThpyYl4nOOXfv9BV5craSra56SSEbLVNJbpNIrowXDe7YlvYX/eb/jWpKVqYs6i0QzFXoUW7xpN5NeX8qB8kp+NqY/bZolU17p41Clj/JKH+VV/seHKqv8zwO36vPLA/MNo0/7Fpzasx3pLVIDxZ1CestUMgKl3TRFJ9ISiWUq9CjknOPVzzfzy/fX0KVNU6bfOJzjj2vpdSwRiXIq9ChTVlHFg2+vZPbS7ZzTrz1PXTGI1k2TvY4lIjFAhR5Ftu4+wMTXl7A6fx93ntuH20f3IUHnRhGRIKnQo8RneTv58RtLqfQ5pozP4pwTOngdSURijArdY845/vzpBp788Et6t2/Bn6/Nokd6c69jiUgMUqF7qPRQJffOWsEHK7/mewM68uSlA2ieqh+JiNSP2sMjG4pKueW1JawvKuXBC/tx05k9dV4UEWkQFboH5q0p5O4Zy0lKNF67YTin9073OpKIxAEVegT5fI7fz89j8vw8Tu7cmj9eM4QubZp5HUtE4kQwl6B72cx2mNmqatMeNbPtZrY8cLswvDFj394DFdwwbTGT5+dx2SldeGviqSpzEQmpYLbQpwJ/AF49YvrTzrnfhjxRHFpfVMr1Uxezfc9BfnnRiVwzorvGy0Uk5IK5BN2nZpYZ/ijx6UB5JTe9mk1pWSUzbhnBKd11gisRCY86h1yO4jYzywkMybQJWaI48/M5a9i4cz+TrxysMheRsKpvof8R6AUMAgqA39W2oJndbGbZZpZdVFRUz7eLTe+tyGdG9lYmjeylPVlEJOzqVejOuULnXJVzzge8CAw7yrIvOOeynHNZGRkZ9c0Zc7buPsCDs1cyuFsad53X1+s4ItII1KvQzaxjtacXA6tqW7YxqqjycfubywCYPG4wyYkNGdkSEQlOnV+KmtkbwCgg3cy2AY8Ao8xsEOCATcAtYcwYc575KI9lW4p59srBdG2rXRNFJDKC2cvlyhomTwlDlrjw7/U7eW7BOi7P6sJ/DezkdRwRaUQ0FhBCu/eXc9eM5fRIb86jY0/0Oo6INDIq9BBxznHvrBXs2V/B5HGDaZaisyqISGSp0EPk1c8381HuDu7/bj9O6tza6zgi0gip0ENgTf4+Hvsgl9H92jPh9Eyv44hII6VCb6AD5ZX8+I2lpDVN5jeXDdA5WkTEMxrobaBfvr+GDTv38/oNw2nXItXrOCLSiGkLvQHm5hTwxiId2i8i0UGFXk/b9hzg/tk5DOqqQ/tFJDqo0OuhssrHHW8uBwfPXqlD+0UkOmgMvR4mz89jyeY9PDNukA7tF5GooU3LY/T5+l08+8k6fnBKFy4a1NnrOCIi31ChH4Pd+8u5c8YyerTTof0iEn005BIk/6H9OezZX8GU8UNpnqpVJyLRRVvoQXrti818lFuoQ/tFJGqp0IOwOn8vv5qrQ/tFJLqp0OuQvWk3V7+0kDbNdGi/iES3OgvdzF42sx1mtqratLZmNs/M8gL3bcIb0xsfrirgqpcW0rZZCrMmnqZD+0UkqgWzhT4VuOCIafcD851zfYD5gedxZdq/NzFp+lJO6tSKWZNO0/7mIhL16ix059ynwO4jJl8ETAs8ngZ8P8S5POPzOf7nb7k8Mmc1553Qgek3jqBt8xSvY4mI1Km++951cM4VADjnCsysfQgzeeZQZRX3zsrh3eX5XDuiO4+OPZHEBI2Zi0hsCPvO1GZ2M3AzQLdu3cL9dvW2r6yCia8t4d/rd3HvBcczaWQvfQEqIjGlvnu5FJpZR4DA/Y7aFnTOveCcy3LOZWVkZNTz7cLr671lXP6nz1m0cTdPXT6QH43qrTIXkZhT30KfA4wPPB4PvBuaOJG3trCES57/F9v2HOSVCUO5ZEgXryOJiNRLnUMuZvYGMApIN7NtwCPA48BMM7sB2A
L8IJwhw+WLDbu4+dVsmiQnMuOWEZzYSUeAikjsqrPQnXNX1jLrnBBniaj3c/K5e8YKurZtyrTrh9GljXZLFJHY1ijPMPXSPzfwq7m5DM1sw4s/zCKtmXZLFJHY16gK3edzPPZBLlM+28h3TzqOp68YRJPkRK9jiYiERKMp9LKKKu55awVzcwq47rRMfjamv/YxF5G4EvOF7pyjospRXuWjvNJHReD+8PPySh+HKn389h9fsWjjbh68sB83ndlTuyWKSNyJiUKfPD+Pd5Zv/8/CrvR9U+TBSE40nhk3SJeNE5G4FROF3r5lKv07tiIlKYGUxARSkhJIDtynVLtPTjRSkhID99XmJSXQvW1zurXTniwiEr9iotDHDevGuGHRe9oAEZFooAtciIjECRW6iEicUKGLiMQJFbqISJxQoYuIxAkVuohInFChi4jECRW6iEicMOdc5N7MrAjYXM+XpwM7Qxgn1JSvYZSvYZSv4aI5Y3fnXJ3X8IxooTeEmWU757K8zlEb5WsY5WsY5Wu4WMhYFw25iIjECRW6iEiciKVCf8HrAHVQvoZRvoZRvoaLhYxHFTNj6CIicnSxtIUuIiJHEXWFbmYXmNlXZrbOzO6vYX6qmc0IzF9oZpkRzNbVzD4xs1wzW21md9SwzCgz22tmywO3hyOVL/D+m8xsZeC9s2uYb2Y2ObD+csxsSASzHV9tvSw3s31mducRy0R0/ZnZy2a2w8xWVZvW1szmmVle4L5NLa8dH1gmz8zGRzDfb8zsy8DP720zS6vltUf9LIQx36Nmtr3az/DCWl571N/1MOabUS3bJjNbXstrw77+Qs45FzU3IBFYD/QEUoAVQP8jlvkR8KfA43HAjAjm6wgMCTxuCaytId8o4H0P1+EmIP0o8y8E/gYYMAJY6OHP+mv8+9d6tv6As4AhwKpq054E7g88vh94oobXtQU2BO7bBB63iVC+84GkwOMnasoXzGchjPkeBX4SxM//qL/r4cp3xPzfAQ97tf5CfYu2LfRhwDrn3AbnXDnwJnDREctcBEwLPJ4FnGMRuuKzc67AObc08LgEyAVi7SKlFwGvOr8vgDQz6+hBjnOA9c65+h5oFhLOuU+B3UdMrv4ZmwZ8v4aXfgdJnwtdAAADKUlEQVSY55zb7ZzbA8wDLohEPufcP5xzlYGnXwBdQv2+wapl/QUjmN/1BjtavkBvXA68Eer39Uq0FXpnYGu159v4dmF+s0zgQ70XaBeRdNUEhnoGAwtrmH2qma0ws7+Z2YkRDQYO+IeZLTGzm2uYH8w6joRx1P6L5OX6A+jgnCsA/3/iQPsalomW9Xg9/r+4alLXZyGcbgsMCb1cy5BVNKy/M4FC51xeLfO9XH/1Em2FXtOW9pG74QSzTFiZWQvgr8Cdzrl9R8xein8YYSDwLPBOJLMBpzvnhgDfBW41s7OOmB8N6y8FGAu8VcNsr9dfsKJhPT4EVALTa1mkrs9CuPwR6AUMAgrwD2scyfP1B1zJ0bfOvVp/9RZthb4N6FrteRcgv7ZlzCwJaE39/uSrFzNLxl/m051zs4+c75zb55wrDTz+AEg2s/RI5XPO5QfudwBv4//Ttrpg1nG4fRdY6pwrPHKG1+svoPDwMFTgfkcNy3i6HgNfwo4BrnaBAd8jBfFZCAvnXKFzrso55wNerOV9vV5/ScAlwIzalvFq/TVEtBX6YqCPmfUIbMWNA+Ycscwc4PAeBZcBH9f2gQ61wJjbFCDXOfdULcscd3hM38yG4V/HuyKUr7mZtTz8GP+XZ6uOWGwO8MPA3i4jgL2HhxciqNYtIy/XXzXVP2PjgXdrWObvwPlm1iYwpHB+YFrYmdkFwH3AWOfcgVqWCeazEK581b+TubiW9w3mdz2czgW+dM5tq2mml+uvQbz+VvbIG/69MNbi/wb8ocC0X+D/8AI0wf+n+jpgEdAzgtnOwP9nYQ6wPHC7EJgITAwscxuwGv+39l8Ap0UwX8/A+64IZDi8/q
rnM+C5wPpdCWRF+OfbDH9Bt642zbP1h/8/lgKgAv9W4w34v5OZD+QF7tsGls0CXqr22usDn8N1wIQI5luHf/z58Gfw8F5fnYAPjvZZiFC+1wKfrRz8Jd3xyHyB59/6XY9EvsD0qYc/c9WWjfj6C/VNR4qKiMSJaBtyERGRelKhi4jECRW6iEicUKGLiMQJFbqISJxQoYuIxAkVuohInFChi4jEif8DasI6ecekKXEAAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x7f31f50946a0>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Synthetic series: linear trend 2*x + 8 plus a periodic residual 2*sin(5*x).\n", | |
"# Use a float range explicitly: with all-integer arguments, newer PyTorch infers int64 for torch.arange.\n", | |
"X = torch.arange(20.)\n", | |
"Y = 2 * X + 8 + 2 * torch.sin(5 * X)\n", | |
"plt.plot(X.numpy(), Y.numpy());" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Periodic-kernel GP that models the residual left after removing the linear trend.\n", | |
"# It is constructed with a one-point slice (X[:1], Y[:1]) only so its tensors get the right type;\n", | |
"# the real training data is swapped in per batch via set_data() inside model().\n", | |
"kernel = gp.kernels.Periodic(1)\n", | |
"residual_model = gp.models.GPRegression(X[:1], Y[:1], kernel, noise=torch.tensor(0.01))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def model(x, y):\n", | |
"    # Learnable linear trend a * x + b; its residual y - trend becomes the GP's target,\n", | |
"    # so the trend parameters are learned jointly with the GP.\n", | |
"    a = pyro.param(\"a\", torch.tensor(0.1, requires_grad=True))\n", | |
"    b = pyro.param(\"b\", torch.tensor(1.0, requires_grad=True))\n", | |
"    trend = a * x + b\n", | |
"    residual = y - trend\n", | |
"    # Point the GP at the current batch (and its residual) before evaluating its model.\n", | |
"    residual_model.set_data(x, residual)\n", | |
"    residual_model.model()\n", | |
"\n", | |
"\n", | |
"def guide(x, y):\n", | |
"    # x and y are unused here; the guide only delegates to the GP's own guide.\n", | |
"    residual_model.guide()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Mini-batch training: X.split(5) yields 4 consecutive batches of 5 points, and each\n", | |
"# svi.step call uses one batch (model() swaps that batch into the GP via set_data).\n", | |
"batches_X = X.split(5)\n", | |
"batches_Y = Y.split(5)\n", | |
"optim = pyro.optim.Adam({\"lr\": 0.1})\n", | |
"# NOTE(review): passing the string \"ELBO\" is the early SVI API; newer Pyro releases\n", | |
"# expect loss=pyro.infer.Trace_ELBO() instead.\n", | |
"svi = pyro.infer.SVI(model, guide, optim, \"ELBO\")\n", | |
"\n", | |
"pyro.clear_param_store()\n", | |
"for _ in range(1000):  # epochs\n", | |
"    for X_i, Y_i in zip(batches_X, batches_Y):\n", | |
"        svi.step(X_i, Y_i)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"\n", | |
" 2.0091\n", | |
"[torch.FloatTensor of size ()]" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Learned slope of the trend (the data were generated with slope 2).\n", | |
"pyro.param(\"a\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"\n", | |
" 8.1749\n", | |
"[torch.FloatTensor of size ()]" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Learned intercept of the trend (the data were generated with intercept 8).\n", | |
"pyro.param(\"b\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"\n", | |
" 0.9715\n", | |
"[torch.FloatTensor of size (1,)]" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The true period of 2*sin(5*x) is 2*pi/5 ~ 1.257. Because X is sampled only at integers,\n", | |
"# the period is identifiable only up to aliasing, so the learned value may differ.\n", | |
"kernel.get_param(\"period\")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python (pyro)", | |
"language": "python", | |
"name": "pyro" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment