@NTT123
Last active October 11, 2018 05:41
HMM.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "HMM.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"[View in Colaboratory](https://colab.research.google.com/gist/NTT123/ec8a3a531e0765ae5d75276bc184cb91/hmm.ipynb)"
]
},
{
"metadata": {
"id": "qZafBPDoEFA1",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Hidden Markov Models\n",
"\nAuthor: Thong Nguyen\n",
"\n",
"Based on \"A Tutorial on Hidden Markov Models and Selected Applications in Speech Recognition\" by Lawrence R. Rabiner, *Proceedings of the IEEE*, 1989.\n"
]
},
{
"metadata": {
"id": "qXeN7kfCEBLw",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Markov chains"
]
},
{
"metadata": {
"id": "k4xUOVM6E8PD",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
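"Let $\\pi(t)$ be the column vector of state probabilities at time step $t$, and let $A$ be the transition matrix with entries $a_{ij} = P(x_{t+1} = j \\mid x_t = i)$, so each row of $A$ sums to 1. Then:\n",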
"$$\n",
"\\pi(t+1) = A^T \\pi(t)\n",
"$$"
]
},
{
"metadata": {
"id": "itJNf1WWkgNT",
"colab_type": "code",
"colab": {},
"cellView": "form"
},
"cell_type": "code",
"source": [
"#@title\n",
"# PyTorch (http://pytorch.org/) is preinstalled on Colab; elsewhere, install it with pip.\n",
"!pip install -q torch\n",
"\n",
"# Weekly price series (date, price); parsed into a tensor further below\n",
"data = \"\"\"2013-01-06,7.794975\n",
"2013-01-13,7.863400\n",
"2013-01-20,8.234920\n",
"2013-01-27,8.186260\n",
"2013-02-03,8.317480\n",
"2013-02-10,8.154040\n",
"2013-02-17,7.883740\n",
"2013-02-24,7.859420\n",
"2013-03-03,7.958560\n",
"2013-03-10,7.976260\n",
"2013-03-17,8.145660\n",
"2013-03-24,8.306540\n",
"2013-03-31,8.248350\n",
"2013-04-07,7.270060\n",
"2013-04-14,7.384700\n",
"2013-04-21,7.469860\n",
"2013-04-28,7.311920\n",
"2013-05-05,7.735940\n",
"2013-05-12,7.534840\n",
"2013-05-19,7.678740\n",
"2013-05-26,7.703500\n",
"2013-06-02,7.800800\n",
"2013-06-09,7.804360\n",
"2013-06-16,7.689940\n",
"2013-06-23,7.919940\n",
"2013-06-30,7.535260\n",
"2013-07-07,7.011460\n",
"2013-07-14,7.207520\n",
"2013-07-21,7.028780\n",
"2013-07-28,6.757800\n",
"2013-08-04,6.589520\n",
"2013-08-11,6.398180\n",
"2013-08-18,6.435180\n",
"2013-08-25,6.641880\n",
"2013-09-01,6.791300\n",
"2013-09-08,6.584620\n",
"2013-09-15,6.511300\n",
"2013-09-22,6.366740\n",
"2013-09-29,6.333920\n",
"2013-10-06,6.154460\n",
"2013-10-13,6.166320\n",
"2013-10-20,6.167020\n",
"2013-10-27,6.159360\n",
"2013-11-03,6.001520\n",
"2013-11-10,5.923280\n",
"2013-11-17,5.992420\n",
"2013-11-24,5.845100\n",
"2013-12-01,5.869520\n",
"2013-12-08,5.945120\n",
"2013-12-15,5.980200\n",
"2013-12-22,5.888720\n",
"2013-12-29,5.926050\n",
"2014-01-05,5.810850\n",
"2014-01-12,5.813080\n",
"2014-01-19,5.897680\n",
"2014-01-26,5.868780\n",
"2014-02-02,5.938920\n",
"2014-02-09,6.076480\n",
"2014-02-16,6.082700\n",
"2014-02-23,6.211280\n",
"2014-03-02,6.228440\n",
"2014-03-09,6.551060\n",
"2014-03-16,6.559180\n",
"2014-03-23,6.530720\n",
"2014-03-30,6.629660\n",
"2014-04-06,6.795720\n",
"2014-04-13,6.795040\n",
"2014-04-20,6.775225\n",
"2014-04-27,6.766620\n",
"2014-05-04,6.864440\n",
"2014-05-11,6.870480\n",
"2014-05-18,6.607120\n",
"2014-05-25,6.374580\n",
"2014-06-01,6.312940\n",
"2014-06-08,6.130040\n",
"2014-06-15,5.978580\n",
"2014-06-22,5.962500\n",
"2014-06-29,5.949760\n",
"2014-07-06,5.604380\n",
"2014-07-13,5.314360\n",
"2014-07-20,5.142040\n",
"2014-07-27,4.955360\n",
"2014-08-03,4.946680\n",
"2014-08-10,4.938600\n",
"2014-08-17,4.972100\n",
"2014-08-24,4.956000\n",
"2014-08-31,4.904480\n",
"2014-09-07,4.776020\n",
"2014-09-14,4.598040\n",
"2014-09-21,4.545860\n",
"2014-09-28,4.374580\n",
"2014-10-05,4.319680\n",
"2014-10-12,4.541180\n",
"2014-10-19,4.689040\n",
"2014-10-26,4.735860\n",
"2014-11-02,4.958700\n",
"2014-11-09,4.939940\n",
"2014-11-16,5.055720\n",
"2014-11-23,4.970060\n",
"2014-11-30,5.018660\n",
"2014-12-07,5.015720\n",
"2014-12-14,5.130840\n",
"2014-12-21,5.291280\n",
"2014-12-28,5.330100\n",
"2015-01-04,5.213650\n",
"2015-01-11,5.181300\n",
"2015-01-18,5.009220\n",
"2015-01-25,5.006000\n",
"2015-02-01,4.870800\n",
"2015-02-08,4.940660\n",
"2015-02-15,5.009900\n",
"2015-02-22,5.010540\n",
"2015-03-01,4.906280\n",
"2015-03-08,4.922140\n",
"2015-03-15,4.906320\n",
"2015-03-22,4.772120\n",
"2015-03-29,4.964580\n",
"2015-04-05,4.871475\n",
"2015-04-12,4.817700\n",
"2015-04-19,4.749300\n",
"2015-04-26,4.706060\n",
"2015-05-03,4.545440\n",
"2015-05-10,4.517460\n",
"2015-05-17,4.520540\n",
"2015-05-24,4.516200\n",
"2015-05-31,4.402940\n",
"2015-06-07,4.465800\n",
"2015-06-14,4.472640\n",
"2015-06-21,4.411040\n",
"2015-06-28,4.623160\n",
"2015-07-05,5.142840\n",
"2015-07-12,5.255180\n",
"2015-07-19,5.296720\n",
"2015-07-26,4.961260\n",
"2015-08-02,4.601720\n",
"2015-08-09,4.579440\n",
"2015-08-16,4.593880\n",
"2015-08-23,4.545800\n",
"2015-08-30,4.525360\n",
"2015-09-06,4.411720\n",
"2015-09-13,4.473660\n",
"2015-09-20,4.633600\n",
"2015-09-27,4.613740\n",
"2015-10-04,4.671460\n",
"2015-10-11,4.713540\n",
"2015-10-18,4.559620\n",
"2015-10-25,4.540980\n",
"2015-11-01,4.576460\n",
"2015-11-08,4.528960\n",
"2015-11-15,4.352200\n",
"2015-11-22,4.355200\n",
"2015-11-29,4.382240\n",
"2015-12-06,4.426340\n",
"2015-12-13,4.431680\n",
"2015-12-20,4.433440\n",
"2015-12-27,4.335400\n",
"2016-01-03,4.254175\n",
"2016-01-10,4.179460\n",
"2016-01-17,4.220240\n",
"2016-01-24,4.338940\n",
"2016-01-31,4.356660\n",
"2016-02-07,4.369680\n",
"2016-02-14,4.262160\n",
"2016-02-21,4.302920\n",
"2016-02-28,4.266340\n",
"2016-03-06,4.176820\n",
"2016-03-13,4.222440\n",
"2016-03-20,4.303180\n",
"2016-03-27,4.322175\n",
"2016-04-03,4.248160\n",
"2016-04-10,4.194940\n",
"2016-04-17,4.318960\n",
"2016-04-24,4.480920\n",
"2016-05-01,4.478360\n",
"2016-05-08,4.386540\n",
"2016-05-15,4.402120\n",
"2016-05-22,4.557520\n",
"2016-05-29,4.670100\n",
"2016-06-05,4.770040\n",
"2016-06-12,4.932320\n",
"2016-06-19,4.986040\n",
"2016-06-26,4.581180\n",
"2016-07-03,4.304040\n",
"2016-07-10,3.994300\n",
"2016-07-17,4.047380\n",
"2016-07-24,3.897380\n",
"2016-07-31,3.820160\n",
"2016-08-07,3.735640\n",
"2016-08-14,3.731200\n",
"2016-08-21,3.811200\n",
"2016-08-28,3.749080\n",
"2016-09-04,3.591260\n",
"2016-09-11,3.732840\n",
"2016-09-18,3.733980\n",
"2016-09-25,3.777040\n",
"2016-10-02,3.707120\n",
"2016-10-09,3.855420\n",
"2016-10-16,3.870520\n",
"2016-10-23,3.960620\n",
"2016-10-30,3.953340\n",
"2016-11-06,3.909120\n",
"2016-11-13,3.859900\n",
"2016-11-20,3.820720\n",
"2016-11-27,3.920300\n",
"2016-12-04,3.813300\n",
"2016-12-11,3.909460\n",
"2016-12-18,3.927520\n",
"2016-12-25,3.816060\n",
"2017-01-01,3.836375\n",
"2017-01-08,3.922450\n",
"2017-01-15,3.913840\n",
"2017-01-22,3.988140\n",
"2017-01-29,3.990860\n",
"2017-02-05,3.970100\n",
"2017-02-12,4.031860\n",
"2017-02-19,4.087060\n",
"2017-02-26,4.017940\n",
"2017-03-05,4.034880\n",
"2017-03-12,3.982420\n",
"2017-03-19,3.899360\n",
"2017-03-26,3.845300\n",
"2017-04-02,3.842600\n",
"2017-04-09,3.889200\n",
"2017-04-16,3.941000\n",
"2017-04-23,3.861580\n",
"2017-04-30,3.859440\n",
"2017-05-07,3.913580\n",
"2017-05-14,3.881540\n",
"2017-05-21,3.875240\n",
"2017-05-28,3.906200\n",
"2017-06-04,3.897800\n",
"2017-06-11,4.011300\n",
"2017-06-18,3.987140\n",
"2017-06-25,3.858920\n",
"2017-07-02,3.805340\n",
"2017-07-09,4.003400\n",
"2017-07-16,4.002000\n",
"2017-07-23,3.949500\n",
"2017-07-30,3.870500\n",
"2017-08-06,3.803000\n",
"2017-08-13,3.806500\n",
"2017-08-20,3.681500\n",
"2017-08-27,3.574500\n",
"2017-09-03,3.512500\n",
"2017-09-10,3.569000\n",
"2017-09-17,3.542500\n",
"2017-09-24,3.507000\n",
"2017-10-01,3.530000\n",
"\"\"\""
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "3LHMxCTCD4Ff",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
},
"outputId": "34c5081c-f202-42d7-b2e8-e964e32728a0"
},
"cell_type": "code",
"source": [
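"\n",
"import torch\n",
"\n",
"# Transition matrix: row i holds P(x_{t+1} = j | x_t = i), so each row sums to 1\n",
"A = torch.tensor([[0.4, 0.3, 0.3], [0.2, 0.6, 0.2], [0.1, 0.1, 0.8]])\n",
"print(A)\n",
"\n",
"# Random initial distribution over the 3 states\n",
"pi = torch.rand(3, 1)\n",
"pi = pi / pi.sum()\n",
"\n",
"# Iterate pi(t+1) = A^T pi(t); after many steps pi approaches the stationary distribution\n",
"for _ in range(1000):\n",
"  pi = torch.matmul(A.t(), pi)\n",
"print(pi)"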
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor([[0.4000, 0.3000, 0.3000],\n",
" [0.2000, 0.6000, 0.2000],\n",
" [0.1000, 0.1000, 0.8000]])\n",
"tensor([[0.1818],\n",
" [0.2727],\n",
" [0.5455]])\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "LNtzvrnJIub4",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"outputId": "d7f8fab5-6217-4e2d-da59-cbb94e5277b5"
},
"cell_type": "code",
"source": [
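"# The stationary distribution is the eigenvector of A^T with eigenvalue 1, rescaled to sum to 1.\n",
"# torch.linalg.eig returns complex eigenpairs in no guaranteed order, so pick the eigenvalue closest to 1.\n",
"evals, evecs = torch.linalg.eig(A.t())\n",
"k = int(torch.argmin((evals - 1).abs()))\n",
"v = evecs[:, k:k+1].real\n",
"v = v / v.sum()\n",
"v"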
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0.1818],\n",
" [0.2727],\n",
" [0.5455]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
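{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"A short derivation of why the two cells above agree. If $\\pi(t)$ converges to a limit $\\pi^*$, letting $t \\to \\infty$ on both sides of $\\pi(t+1) = A^T \\pi(t)$ gives\n",
"$$\n",
"A^T \\pi^* = \\pi^*, \\qquad \\sum_i \\pi^*_i = 1,\n",
"$$\n",
"so $\\pi^*$ is an eigenvector of $A^T$ with eigenvalue $1$, scaled to sum to one. For the matrix $A$ used here, $\\pi^* = (2/11, 3/11, 6/11)^T$. Multiplying through by 11, the condition $(A^T \\pi^*)_j = \\sum_i a_{ij} \\pi^*_i = \\pi^*_j$ reduces to three checks:\n",
"\n",
"- $0.4 \\cdot 2 + 0.2 \\cdot 3 + 0.1 \\cdot 6 = 2$\n",
"- $0.3 \\cdot 2 + 0.6 \\cdot 3 + 0.1 \\cdot 6 = 3$\n",
"- $0.3 \\cdot 2 + 0.2 \\cdot 3 + 0.8 \\cdot 6 = 6$\n",
"\n",
"which matches the printed values $(0.1818, 0.2727, 0.5455)$. Convergence from any starting distribution holds here because every entry of $A$ is positive, which makes the chain irreducible and aperiodic."
]
},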
{
"metadata": {
"id": "IuXrE48oOiI2",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Hidden Markov models\n",
"\n"
]
},
{
"metadata": {
"id": "ydXN26rzOmzb",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"<img alt=\"\" src=\"//upload.wikimedia.org/wikipedia/commons/thumb/8/8a/HiddenMarkovModel.svg/300px-HiddenMarkovModel.svg.png\" width=\"300\" height=\"240\" class=\"thumbimage\" srcset=\"//upload.wikimedia.org/wikipedia/commons/thumb/8/8a/HiddenMarkovModel.svg/450px-HiddenMarkovModel.svg.png 1.5x, //upload.wikimedia.org/wikipedia/commons/thumb/8/8a/HiddenMarkovModel.svg/600px-HiddenMarkovModel.svg.png 2x\" data-file-width=\"750\" data-file-height=\"600\">"
]
},
{
"metadata": {
"id": "t7vbL_qXO_Kt",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Given an observed sequence $y_{1:t} = (y_1, \\dots, y_t)$, we aim to compute its likelihood under the model\n",
"$$\n",
"P(y_{1:t})\n",
"$$"
]
},
{
"metadata": {
"id": "saHu9FtMPvc0",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
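"Here we implement the forward algorithm. Define $\\alpha_t(x_t) = P(x_t, y_{1:t})$, the joint probability of being in hidden state $x_t$ at time $t$ and observing $y_{1:t}$.\n",
"\n",
"Then\n",
"$$\n",
"P(y_{1:t}) = \\sum_{s \\in S} \\alpha_t (s)\n",
"$$\n",
"where $S$ is the set of hidden states.\n",
"\n",
"$\\alpha_t$ satisfies the recursion\n",
"$$\n",
"\\alpha_t(x_t) = b_{x_t, y_t} \\times \\sum_{x_{t-1} \\in S} \\alpha_{t-1}(x_{t-1}) \\times a_{x_{t-1}, x_t} \n",
"$$\n",
"\n",
"with the initial condition\n",
"\n",
"$$\n",
"\\alpha_1(x_1) = \\pi(x_1) \\times b_{x_1, y_1}\n",
"$$\n",
"\n",
"where $a_{ij} = P(x_t = j \\mid x_{t-1} = i)$ is the transition probability and $b_{s, y} = P(y_t = y \\mid x_t = s)$ is the emission probability (a density when the observations are continuous, as in the Gaussian model below). The code works with $\\log \\alpha_t$ and `logsumexp` to avoid numerical underflow on long sequences."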
]
},
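{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"A minimal sketch of the recursion above in plain probability space, for a toy discrete HMM. The two-state matrices `A_toy` and `B_toy`, the initial distribution `pi0`, the observation sequence `obs`, and the helper names `forward_prob` / `brute_force_prob` are illustrative assumptions, not part of the original notebook. The brute-force version sums $\\pi(x_1)\\, b_{x_1, y_1} \\prod_{t \\ge 2} a_{x_{t-1}, x_t} b_{x_t, y_t}$ over all $|S|^T$ hidden paths, so the two printed numbers should agree; the forward recursion reaches the same value in $O(T |S|^2)$ time."
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import itertools\n",
"import torch\n",
"\n",
"# Toy 2-state HMM with 2 observation symbols (illustrative values)\n",
"A_toy = torch.tensor([[0.7, 0.3], [0.4, 0.6]])  # A_toy[i, j] = P(x_t = j | x_{t-1} = i)\n",
"B_toy = torch.tensor([[0.9, 0.1], [0.2, 0.8]])  # B_toy[s, k] = P(y_t = k | x_t = s)\n",
"pi0 = torch.tensor([0.5, 0.5])                  # initial distribution pi(x_1)\n",
"obs = [0, 1, 1, 0]\n",
"\n",
"def forward_prob(A, B, pi0, obs):\n",
"  alpha = pi0 * B[:, obs[0]]          # alpha_1(s) = pi(s) b_{s, y_1}\n",
"  for y_t in obs[1:]:\n",
"    alpha = B[:, y_t] * (alpha @ A)   # b_{s, y_t} * sum_{s'} alpha_{t-1}(s') a_{s', s}\n",
"  return alpha.sum()                  # P(y_{1:T}) = sum_s alpha_T(s)\n",
"\n",
"def brute_force_prob(A, B, pi0, obs):\n",
"  total = torch.tensor(0.0)\n",
"  for path in itertools.product(range(A.size(0)), repeat=len(obs)):\n",
"    p = pi0[path[0]] * B[path[0], obs[0]]\n",
"    for t in range(1, len(obs)):\n",
"      p = p * A[path[t - 1], path[t]] * B[path[t], obs[t]]\n",
"    total = total + p\n",
"  return total\n",
"\n",
"print(forward_prob(A_toy, B_toy, pi0, obs).item(), brute_force_prob(A_toy, B_toy, pi0, obs).item())"
],
"execution_count": 0,
"outputs": []
},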
{
"metadata": {
"id": "bl7fOQGALlEk",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import math\n",
"\n",
"class HMM(torch.nn.Module):\n",
"  # Hidden Markov model with n hidden states and 1-D Gaussian emissions.\n",
"  # forward(y) returns the log-likelihood log P(y_1, ..., y_T).\n",
"  def __init__(self, n):\n",
"    super().__init__()\n",
"\n",
"    self.n = n\n",
"    # Probabilities are stored as unnormalized logits and normalized with log_softmax in forward().\n",
"    # Transition logits: after normalization, row i gives P(x_t = j | x_{t-1} = i).\n",
"    A = 0.01*torch.randn(n, n)\n",
"    self.register_parameter(\"A\", torch.nn.Parameter(A))\n",
"    # Gaussian emission of state s: N(B_mean[s], exp(B_logvar[s]))\n",
"    B_mean = torch.randn(n)\n",
"    B_logvar = 0.01*torch.randn(n)\n",
"    self.register_parameter(\"B_mean\", torch.nn.Parameter(B_mean))\n",
"    self.register_parameter(\"B_logvar\", torch.nn.Parameter(B_logvar))\n",
"    # Initial-state logits\n",
"    pi = 0.01*torch.randn(n)\n",
"    self.register_parameter(\"pi\", torch.nn.Parameter(pi))\n",
"\n",
"  def forward(self, y):\n",
"    # y: 1-D tensor of observations y_1, ..., y_T\n",
"\n",
"    # log transition probabilities and log initial-state probabilities\n",
"    pa = torch.log_softmax(self.A, dim=1)\n",
"    ppi = torch.log_softmax(self.pi, dim=0)\n",
"\n",
"    # emission variances\n",
"    var = torch.exp(self.B_logvar)\n",
"\n",
"    # log N(y_1; B_mean[s], var[s]) for every state s\n",
"    py = -torch.pow(self.B_mean - y[0], 2) / var / 2 - 0.5 * (math.log(2*math.pi) + self.B_logvar)\n",
"\n",
"    # log alpha_1(s) = log pi(s) + log b_s(y_1)\n",
"    alpha = py + ppi\n",
"\n",
"    # Recursion on the log scale:\n",
"    # log alpha_t(j) = log b_j(y_t) + logsumexp_i(log alpha_{t-1}(i) + log a_ij)\n",
"    for t in range(1, y.size(0)):\n",
"      py = -torch.pow(self.B_mean - y[t], 2) / var / 2 - 0.5 * (math.log(2*math.pi) + self.B_logvar)\n",
"      alpha = torch.logsumexp(alpha.view(self.n, 1) + pa, dim=0) + py\n",
"\n",
"    # log P(y_1, ..., y_T) = logsumexp_s log alpha_T(s)\n",
"    return torch.logsumexp(alpha, dim=0)"
],
"execution_count": 0,
"outputs": []
},
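{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"A sanity check of the log-space forward pass, sketched under the assumption that a tiny model is representative: with 3 states and 4 observations there are only $3^4 = 81$ hidden paths, so $\\log P(y_{1:T})$ can also be computed by enumerating every path and combining the path log-probabilities with `logsumexp`. The names `check`, `y_small`, and `brute_force` are illustrative. The Gaussian log-densities come from `torch.distributions.Normal`, which uses the same formula as `HMM.forward`, so the two printed numbers should agree up to floating-point error."
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"import itertools\n",
"\n",
"# Hypothetical small check: 3 states, 4 random observations\n",
"check = HMM(3)\n",
"y_small = torch.randn(4)\n",
"with torch.no_grad():\n",
"  log_a = torch.log_softmax(check.A, dim=1)\n",
"  log_pi = torch.log_softmax(check.pi, dim=0)\n",
"  emission = torch.distributions.Normal(check.B_mean, torch.exp(0.5 * check.B_logvar))\n",
"  # log_b[t, s] = log N(y_t; B_mean[s], exp(B_logvar[s]))\n",
"  log_b = torch.stack([emission.log_prob(v) for v in y_small])\n",
"  path_scores = []\n",
"  for path in itertools.product(range(3), repeat=len(y_small)):\n",
"    score = log_pi[path[0]] + log_b[0, path[0]]\n",
"    for t in range(1, len(y_small)):\n",
"      score = score + log_a[path[t - 1], path[t]] + log_b[t, path[t]]\n",
"    path_scores.append(score)\n",
"  brute_force = torch.logsumexp(torch.stack(path_scores), dim=0)\n",
"  print(check(y_small).item(), brute_force.item())"
],
"execution_count": 0,
"outputs": []
},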
{
"metadata": {
"id": "3336AiR7LyiV",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"model = HMM(5)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "489YhrFrzAFG",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "qa1qu2TrL4MT",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"price = torch.Tensor([float(l.split(\",\")[1]) for l in data.split(\"\\n\")[:-1]])"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "oZpMF-hQ9SG4",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"price = price - price.mean()\n",
"price = price / price.std()"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "x9AFoIffuULY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "a27a0fbe-e8da-47ea-ca85-de362128e6e5"
},
"cell_type": "code",
"source": [
"price.size(0)-5"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"243"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"metadata": {
"id": "vs1lvUgSzF6i",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 2567
},
"outputId": "907bda94-d57d-4dc0-b583-ccecac85056c"
},
"cell_type": "code",
"source": [
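"# Maximum-likelihood training on sliding windows of 5 consecutive weeks.\n",
"# The last 5 points (price[-5:]) never fall inside a training window and serve as a held-out check.\n",
"running_loss = None  # exponential moving average of the training loss (negative log-likelihood)\n",
"for step in range(500000):\n",
"  idx = step % (price.size(0) - 10)\n",
"  loss = -model(price[idx:idx+5])\n",
"  optimizer.zero_grad()\n",
"  loss.backward()\n",
"  optimizer.step()\n",
"\n",
"  if running_loss is None:\n",
"    running_loss = loss.item()\n",
"  else:\n",
"    running_loss = 0.999 * running_loss + 0.001 * loss.item()\n",
"\n",
"  if step % 10000 == 0:\n",
"    print(running_loss)                  # smoothed training NLL\n",
"    with torch.no_grad():\n",
"      print(model(price[-5:]).item())    # log-likelihood of the held-out window\n",
"    print(model.B_mean.data)             # current emission means"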
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"10.42137336730957\n",
"-9.904097557067871\n",
"tensor([ 1.2435, -0.0636, 1.4191, 0.5233, -0.8998])\n",
"1.6659804438897716\n",
"-5.045734405517578\n",
"tensor([ 1.5446, -0.9213, 1.5701, 0.6190, -0.3473])\n",
"0.26294181726871196\n",
"-24.104997634887695\n",
"tensor([ 2.0082, -0.9540, 2.0505, 0.8773, -0.3952])\n",
"0.14269235375517736\n",
"-24.943679809570312\n",
"tensor([ 1.8318, -0.9539, 2.2615, 0.8765, -0.3955])\n",
"0.10994652790125825\n",
"-25.016605377197266\n",
"tensor([ 1.8448, -0.9536, 2.2792, 0.8768, -0.3955])\n",
"0.11389683388326427\n",
"-25.063264846801758\n",
"tensor([ 1.8473, -0.9534, 2.2802, 0.8769, -0.3955])\n",
"0.1267401776018149\n",
"-25.094295501708984\n",
"tensor([ 1.8523, -0.9533, 2.2803, 0.8769, -0.3955])\n",
"0.15214649965553814\n",
"-25.11527442932129\n",
"tensor([ 1.8511, -0.9532, 2.2802, 0.8789, -0.3955])\n",
"0.1590621973444389\n",
"-25.12967300415039\n",
"tensor([ 1.8503, -0.9532, 2.2801, 0.8832, -0.3955])\n",
"0.16278601880001342\n",
"-25.139789581298828\n",
"tensor([ 1.8497, -0.9531, 2.2801, 0.8881, -0.3955])\n",
"0.16419796758234295\n",
"-25.14704704284668\n",
"tensor([ 1.8494, -0.9531, 2.2800, 0.8910, -0.3955])\n",
"0.1716214467047759\n",
"-25.152423858642578\n",
"tensor([ 1.8491, -0.9531, 2.2800, 0.8906, -0.3955])\n",
"0.18095696763137814\n",
"-25.156557083129883\n",
"tensor([ 1.8490, -0.9531, 2.2800, 0.8872, -0.3955])\n",
"0.19245426378362698\n",
"-25.159931182861328\n",
"tensor([ 1.8489, -0.9531, 2.2800, 0.8815, -0.3955])\n",
"0.1977639607957912\n",
"-25.162765502929688\n",
"tensor([ 1.8488, -0.9531, 2.2800, 0.8749, -0.3955])\n",
"0.20026169399188348\n",
"-25.16525650024414\n",
"tensor([ 1.8488, -0.9531, 2.2800, 0.8703, -0.3955])\n",
"0.21241968293279842\n",
"-25.16754150390625\n",
"tensor([ 1.8488, -0.9531, 2.2800, 0.8700, -0.3955])\n",
"0.23024657079536856\n",
"-25.16966438293457\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8740, -0.3955])\n",
"0.23538370344566628\n",
"-25.171716690063477\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8789, -0.3955])\n",
"0.25342020583923763\n",
"-25.173795700073242\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8801, -0.3954])\n",
"0.27578638229087593\n",
"-25.180063247680664\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8790, -0.3933])\n",
"0.2771308563839316\n",
"-25.192819595336914\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8783, -0.3894])\n",
"0.27364688727594383\n",
"-25.204593658447266\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8778, -0.3861])\n",
"0.270464264793765\n",
"-25.212421417236328\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8775, -0.3847])\n",
"0.27361799540998893\n",
"-25.221965789794922\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8773, -0.3821])\n",
"0.28992920987026605\n",
"-25.2320499420166\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8771, -0.3774])\n",
"0.30947759894336596\n",
"-25.240856170654297\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8770, -0.3706])\n",
"0.31196499690105706\n",
"-25.24869155883789\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8770, -0.3635])\n",
"0.3114140619238951\n",
"-25.255496978759766\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3568])\n",
"0.30763316274840796\n",
"-25.26129913330078\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3510])\n",
"0.30192596444441794\n",
"-25.26610565185547\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3469])\n",
"0.2991548043879867\n",
"-25.26995277404785\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3455])\n",
"0.30287387135995864\n",
"-25.273332595825195\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3453])\n",
"0.3109499857409771\n",
"-25.2768611907959\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3433])\n",
"0.30617735598520895\n",
"-25.28030014038086\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3421])\n",
"0.30140305379253257\n",
"-25.28352165222168\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3424])\n",
"0.2949640793011718\n",
"-25.286760330200195\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3432])\n",
"0.29375283307946354\n",
"-25.28981590270996\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3450])\n",
"0.2951011015086103\n",
"-25.29248046875\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3483])\n",
"0.3039233211118434\n",
"-25.294673919677734\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3530])\n",
"0.31003409113126457\n",
"-25.2957706451416\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3589])\n",
"0.3194086324292594\n",
"-25.29519271850586\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3655])\n",
"0.3268512756317831\n",
"-25.291227340698242\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3727])\n",
"0.3265047370139794\n",
"-25.28326416015625\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3796])\n",
"0.32223086287632036\n",
"-25.274822235107422\n",
"tensor([ 1.8487, -0.9531, 2.2800, 0.8769, -0.3849])\n",
"0.32529371419852443\n",
"-25.279953002929688\n",
"tensor([ 1.8487, -0.9530, 2.2800, 0.8769, -0.3880])\n",
"0.3211423304196827\n",
"-25.399688720703125\n",
"tensor([ 1.8487, -0.9523, 2.2800, 0.8769, -0.3905])\n",
"0.31910119447616425\n",
"-24.66801643371582\n",
"tensor([ 1.8487, -0.9554, 2.2800, 0.8769, -0.3922])\n",
"0.32236833639177226\n",
"-22.78359031677246\n",
"tensor([ 1.8487, -0.9633, 2.2800, 0.8769, -0.3933])\n",
"0.30346350158836055\n",
"-20.935359954833984\n",
"tensor([ 1.8487, -0.9713, 2.2800, 0.8769, -0.3941])\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "8USfK1A5zz9r",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "67426f52-6e69-47a5-d106-de8ba772a6ca"
},
"cell_type": "code",
"source": [
"# standard deviation of each state's Gaussian emission\n",
"torch.exp(0.5*model.B_logvar)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([0.2090, 0.0814, 0.1326, 0.2928, 0.2456], grad_fn=<ExpBackward>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"metadata": {
"id": "16RKQkLVBFab",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 71
},
"outputId": "9bd14b22-14fa-4d2d-b1cf-602dc965f4cb"
},
"cell_type": "code",
"source": [
"model.B_mean"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([ 1.8487, -0.9752, 2.2800, 0.8769, -0.3944], requires_grad=True)"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"metadata": {
"id": "pNYfdnwA3TG_",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"outputId": "2be86403-2452-4f0f-de29-ba320a404671"
},
"cell_type": "code",
"source": [
"# learned transition matrix (row i: P(next state | current state i)), rounded to 3 decimals\n",
"torch.round(1000*torch.softmax(model.A, dim=1 ))/1000"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.0000, 1.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.0970, 0.0000, 0.9030, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000, 0.9930, 0.0070],\n",
" [0.0000, 0.0050, 0.0000, 0.0000, 0.9950]], grad_fn=<DivBackward0>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
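{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"A sketch for reading the learned transition matrix: in a Markov chain, the number of consecutive steps spent in state $i$ is geometrically distributed with mean $1/(1 - a_{ii})$. Because the observations are weekly, this gives the expected length of a stay in each hidden state in weeks. The variable names are illustrative, and the `clamp` is an added guard against division by zero for states whose self-transition probability is numerically 1 (the first two rows above are close to absorbing, so expect very large values there)."
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Expected duration (in weeks) of a stay in each hidden state: 1 / (1 - a_ii)\n",
"with torch.no_grad():\n",
"  self_loop = torch.diag(torch.softmax(model.A, dim=1))\n",
"  expected_weeks = 1.0 / (1.0 - self_loop).clamp(min=1e-12)\n",
"print(expected_weeks)"
],
"execution_count": 0,
"outputs": []
},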
{
"metadata": {
"id": "sznkRkJb4IVr",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
},
"outputId": "922cfbc3-e623-4783-a22d-de4586dc27ef"
},
"cell_type": "code",
"source": [
"# learned initial-state distribution\n",
"torch.softmax(model.pi, dim=0)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([0.0515, 0.2224, 0.0535, 0.2139, 0.4587], grad_fn=<SoftmaxBackward>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"metadata": {
"id": "f4UF1QKi3yR7",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 452
},
"outputId": "7e0b3c4a-26db-4b78-e7c3-5d70038c8339"
},
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.hist(price.numpy())"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(array([54., 43., 47., 25., 8., 23., 19., 5., 11., 13.]),\n",
" array([-1.25947654, -0.88813452, -0.51679249, -0.14545046, 0.22589157,\n",
" 0.59723359, 0.96857562, 1.33991765, 1.71125968, 2.0826017 ,\n",
" 2.45394373]),\n",
" <a list of 10 Patch objects>)"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAFKCAYAAABcq1WoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAE6pJREFUeJzt3X1s3WX9//HX1qbOSoExW2CCaAwK\nAooLNzLSYTcgDhOdRmBrwCg3QrhxJNNRkQiJCY6Bi6CJQ2TTYIgNDSF8E2IXBBIMtSgYIsRk3CQE\nuZkdVhisdaw5vz9+sWGy9XRde7Xn7PH4qz2372tX4JnPp+3nzKpUKpUAAMXMnu4BAGB/I74AUJj4\nAkBh4gsAhYkvABQmvgBQWGOJNxkY2FbibabU3LnNGRzcPt1jTIl6XVu9riuxtlpUr+tKrG1PWltb\n9nifI99xamxsmO4Rpky9rq1e15VYWy2q13Ul1jYR4gsAhYkvABQmvgBQmPgCQGHiCwCFiS8AFCa+\nAFCY+AJAYeILAIWJLwAUJr4AUJj4AkBhRT7VaCpctObh6R5hTBu6Fk/3CADMUI58AaAw8QWAwsQX\nAAoTXwAoTHwBoDDxBYDCxBcAChNfAChMfAGgMPEFgMLEFwAKE18AKEx8AaAw8QWAwsQXAAoTXwAo\nTHwBoLDGag/o7+/PypUrc/TRRydJPvnJT+aSSy7J6tWrMzIyktbW1txyyy1pamqa8mEBoB5UjW+S\nnHLKKbn99ttHv//+97+fzs7OLF26NOvWrUtPT086OzunbEgAqCcTOu3c39+fJUuWJEk6OjrS19c3\nqUMBQD0b15Hv888/n8svvzxvvvlmrrrqqgwNDY2eZp43b14GBgbGfP7cuc1pbGzY92lrSGtry3SP\nsFdqbd7xqtd1JdZWi+p1XYm17a2q8f3Yxz6Wq666KkuXLs3LL7+cb3zjGxkZGRm9v1KpVH2TwcHt\n+zZlDRoY2DbdI4xba2tLTc07XvW6rsTaalG9riuxtrGeuydVTzsfeuihOeecczJr1qx89KMfzYc/\n/OG8+eabGR4eTpJs2bIlbW1tExoMAPZHVeP7wAMP5K677kqSDAwM5I033sjXvva19Pb2Jkk2bdqU\n9vb2qZ0SAOpI1dPOixcvzne/+9384Q9/yLvvvpsbb7wxxx57bK699tp0d3dn/vz5WbZsWYlZAaAu\nVI3vAQcckPXr17/v9o0bN07JQABQ71zhCgAKE18AKEx8AaAw8QWAwsQXAAoTXwAoTHwBoDDxBYDC\nxBcAChNfAChsXJ/nS326aM3D0z3CmDZ0LZ7uEQCmhCNfAChMfAGgMPEFgMLEFwAKE18AKEx8AaAw\n8QWAwsQXAAoTXwAoTHwBoDDxBYDCxBcAChNfAChMfAGgMPEFgMLEFwAKE18AKEx8AaAw8QWAwsQX\nAAprnO4B6tVFax6e7hEAmKEc+QJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkvABQmvgBQmPgCQGHi\nCwCFiS8AFDau+A4PD+fMM8/Mfffdl9deey0XXnhhOjs7s3LlyuzYsWOqZwSAujKu+P7iF7/IQQcd\nlCS5/fbb09nZmXvuuSdHHXVUenp6pnRAAKg3VeP7wgsv5Pnnn88XvvCFJEl/f3+WLFmSJOno6Ehf\nX9+UDggA9aZqfG+++eZ0dXWNfj80NJSmpqYkybx58zIwMDB10wFAHRrz83zvv//+nHjiiTnyyCN3\ne3+lUhnXm8yd25zGxoa9n479Wmtry4x4jZnK2mpPva4rsba9NWZ8H3300bz88st59NFH8/rrr6ep\nqSnNzc0ZHh7OnDlzsmXLlrS1tVV9k8HB7ZM2MPuPgYFt+/T81taWfX6Nmcraak+9riuxtrGeuydj\nxvenP/3p6Nc/+9nP8pGPfCR//etf09vbm698
5SvZtGlT2tvbJzQUAOyv9vrvfK+++urcf//96ezs\nzL///e8sW7ZsKuYCgLo15pHve1199dWjX2/cuHFKhgGA/YErXAFAYeILAIWJLwAUJr4AUJj4AkBh\n4gsAhYkvABQmvgBQmPgCQGHiCwCFiS8AFCa+AFCY+AJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkv\nABQmvgBQmPgCQGHiCwCFiS8AFCa+AFCY+AJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkvABQmvgBQ\nmPgCQGHiCwCFiS8AFCa+AFCY+AJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkvABQmvgBQWGO1BwwN\nDaWrqytvvPFG/vOf/+SKK67IMccck9WrV2dkZCStra255ZZb0tTUVGJeAKh5VeP7yCOP5Pjjj8+l\nl16aV155JRdddFEWLFiQzs7OLF26NOvWrUtPT086OztLzAsANa/qaedzzjknl156aZLktddey6GH\nHpr+/v4sWbIkSdLR0ZG+vr6pnRIA6kjVI9//Wr58eV5//fWsX78+3/rWt0ZPM8+bNy8DAwNTNiAA\n1Jtxx/d3v/td/v73v+d73/teKpXK6O3v/XpP5s5tTmNjw8QmZL/V2toyI15jprK22lOv60qsbW9V\nje8zzzyTefPm5fDDD8+xxx6bkZGRfOhDH8rw8HDmzJmTLVu2pK2tbczXGBzcPmkDs/8YGNi2T89v\nbW3Z59eYqayt9tTruhJrG+u5e1L1Z75/+ctfsmHDhiTJ1q1bs3379ixcuDC9vb1Jkk2bNqW9vX1C\ngwHA/qjqke/y5cvzgx/8IJ2dnRkeHs4Pf/jDHH/88bn22mvT3d2d+fPnZ9myZSVmBYC6UDW+c+bM\nyU9+8pP33b5x48YpGQgA6p0rXAFAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkvABQmvgBQmPgCQGHi\nCwCFiS8AFCa+AFCY+AJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkvABQmvgBQmPgCQGGN0z0A7MlF\nax6e7hGq2tC1eLpHAGqQI18AKEx8AaAw8QWAwsQXAAoTXwAoTHwBoDDxBYDCxBcAChNfAChMfAGg\nMJeXhDo30y/T6RKd7I8c+QJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkvABQmvgBQmPgCQGHiCwCF\niS8AFDauazuvXbs2Tz75ZHbu3JnLLrssJ5xwQlavXp2RkZG0trbmlltuSVNT01TPCgB1oWp8//Sn\nP+W5555Ld3d3BgcH89WvfjWnnXZaOjs7s3Tp0qxbty49PT3p7OwsMS8A1Lyqp51PPvnk3HbbbUmS\nAw88MENDQ+nv78+SJUuSJB0dHenr65vaKQGgjlQ98m1oaEhzc3OSpKenJ4sWLcof//jH0dPM8+bN\ny8DAwJivMXducxobGyZhXJhZWltb9sv3nky7W0e9rO1/1eu6EmvbW+P+PN+HHnooPT092bBhQ84+\n++zR2yuVStXnDg5un9h0MMMNDGyblvdtbW2ZtveebP+7jnpa23vV67oSaxvruXsyrt92fuyxx7J+\n/frceeedaWlpSXNzc4aHh5MkW7ZsSVtb24QGA4D9UdX4btu2LWvXrs0dd9yRgw8+OEmycOHC9Pb2\nJkk2bdqU9vb2qZ0SAOpI1dPODz74YAYHB3PNNdeM3rZmzZpcf/316e7uzvz587Ns2bIpHRIA6knV\n+J5//vk5//zz33f7xo0bp2QgAKh34/6FK4CpcNGah6d7hKo2dC2e7hGoMy4vCQCFiS8AFCa+AFCY\n+AJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkvABQmvgBQmPgCQGHiCwCFiS8AFCa+AFCY+AJAYeIL\nAIWJLwAUJr4AUJj4AkBh4gsAhYkvABQmvgBQmPgCQGHiCwCFiS8AFCa+AFCY+AJAYeILAIWJLwAU\n1jjdAwBQ
/y5a8/B0j1DVhq7Fxd7LkS8AFCa+AFCY+AJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkv\nABQmvgBQmMtLAtS4Wrh0I7sa15Hv5s2bc+aZZ+a3v/1tkuS1117LhRdemM7OzqxcuTI7duyY0iEB\noJ5Uje/27dvzox/9KKeddtrobbfffns6Oztzzz335KijjkpPT8+UDgkA9aRqfJuamnLnnXemra1t\n9Lb+/v4sWbIkSdLR0ZG+vr6pmxAA6kzVn/k2NjamsXHXhw0NDaWpqSlJMm/evAwMDEzNdABQh/b5\nF64qlUrVx8yd25zGxoZ9fSuYcVpbW/bL997fTNa/tT2b2fa0P1OxbxOKb3Nzc4aHhzNnzpxs2bJl\nl1PSuzM4uH1Cw8FMNzCwbVret7W1Zdree380Gf/W9mzm293+7Mu+jRXtCf2d78KFC9Pb25sk2bRp\nU9rb2yc0GADsj6oe+T7zzDO5+eab88orr6SxsTG9vb259dZb09XVle7u7syfPz/Lli0rMSsA1IWq\n8T3++ONz9913v+/2jRs3TslAAFDvXF4SAAoTXwAoTHwBoDDxBYDCxBcAChNfACjM5/nCPvA5qsBE\nOPIFgMLEFwAKE18AKEx8AaAw8QWAwsQXAAoTXwAoTHwBoDDxBYDCxBcAChNfAChMfAGgMPEFgMLE\nFwAK85GCAFX46EgmmyNfAChMfAGgMPEFgMLEFwAKE18AKEx8AaAw8QWAwsQXAAoTXwAoTHwBoDDx\nBYDCxBcAChNfAChMfAGgMPEFgMLEFwAKE18AKEx8AaAw8QWAwsQXAAoTXwAoTHwBoLDGiT7xpptu\nytNPP51Zs2bluuuuy2c+85nJnAsA6taE4vvEE0/kpZdeSnd3d1544YVcd9116e7unuzZAKAuTei0\nc19fX84888wkySc+8Ym8+eabefvttyd1MACoVxOK79atWzN37tzR7w855JAMDAxM2lAAUM8m/DPf\n96pUKmPe39raMhlvs4v/+8lXJv01AeB/TUXDJnTk29bWlq1bt45+/89//jOtra2TNhQA1LMJxff0\n009Pb29vkuTZZ59NW1tbDjjggEkdDADq1YROOy9YsCDHHXdcli9fnlmzZuWGG26Y7LkAoG7NqlT7\ngS0AMKlc4QoAChNfAChsUv7UqF498cQTWblyZW666aZ0dHS87/4HHnggv/nNbzJ79uycd955Offc\nc6dhyr3z7rvvpqurK6+++moaGhry4x//OEceeeQujznuuOOyYMGC0e9//etfp6GhofSoe2Wsy50+\n/vjjWbduXRoaGrJo0aJceeWV0zjp3hlrXYsXL85hhx02uje33nprDj300Okada9t3rw5V1xxRb75\nzW/mggsu2OW+Wt6zZOy11fq+rV27Nk8++WR27tyZyy67LGefffbofbW8b2Ota0r2rMJuvfTSS5XL\nL7+8csUVV1Qefvjh993/zjvvVM4+++zKW2+9VRkaGqp86UtfqgwODk7DpHvnvvvuq9x4442VSqVS\neeyxxyorV65832NOOeWU0mPtk/7+/sq3v/3tSqVSqTz//POV8847b5f7ly5dWnn11VcrIyMjlRUr\nVlSee+656Rhzr1VbV0dHR+Xtt9+ejtH22TvvvFO54IILKtdff33l7rvvft/9tbpnlUr1tdXyvvX1\n9VUuueSSSqVSqfzrX/+qnHHGGbvcX6v7Vm1dU7FnTjvvQWtra37+85+npWX3f1z99NNP54QTTkhL\nS0vmzJmTBQsW5Kmnnio85d7r6+vLWWedlSRZuHBhTcxczViXO3355Zdz0EEH5fDDD8/s2bNzxhln\npK+vbzrHHbd6voxrU1NT7rzzzrS1tb3vvlres2TstdW6k08+ObfddluS5MADD8zQ0FBGRkaS1Pa+\njbWuqSK+e/DBD35wzFOtW7duzSGHHDL6fa1cYvO9c8+ePTuzZs3Kjh07dn
nMjh07smrVqixfvjwb\nN26cjjH3yliXOx0YGKjJfUrGdxnXG264IStWrMitt95a9UpzM0ljY2PmzJmz2/tqec+Ssdf2X7W6\nbw0NDWlubk6S9PT0ZNGiRaP/n6zlfRtrXf812XvmZ75J7r333tx777273Hb11Venvb193K8xE/8D\n2t26nn766V2+393cq1evzpe//OXMmjUrF1xwQU466aSccMIJUzrrZJqJezEZ/ndd3/nOd9Le3p6D\nDjooV155ZXp7e/PFL35xmqZjvOph3x566KH09PRkw4YN0z3KpNrTuqZiz8Q3ybnnnrvXvyy1u0ts\nnnjiiZM92j7Z3bq6uroyMDCQY445Ju+++24qlUqampp2ecyKFStGv/785z+fzZs3z+j4jnW50/+9\nb8uWLTVzOrDaZVyXLVs2+vWiRYuyefPmmvuf+O7U8p6NR63v22OPPZb169fnV7/61S4/lqv1fdvT\nupKp2TOnnSfos5/9bP72t7/lrbfeyjvvvJOnnnoqJ5100nSPVdXpp5+e3//+90mSRx55JKeeeuou\n97/44otZtWpVKpVKdu7cmaeeeipHH330dIw6bmNd7vSII47I22+/nX/84x/ZuXNnHnnkkZx++unT\nOe64jbWubdu25eKLLx79kcGf//znGb9P41XLe1ZNre/btm3bsnbt2txxxx05+OCDd7mvlvdtrHVN\n1Z458t2DRx99NHfddVdefPHFPPvss7n77ruzYcOG/PKXv8zJJ5+cz33uc1m1alUuvvjizJo1K1de\neeUefzlrJjnnnHPy+OOPZ8WKFWlqasqaNWuSZJd1HXbYYfn617+e2bNnZ/Hixbv8ectMtLvLnd53\n331paWnJWWedlRtvvDGrVq1K8v/X//GPf3yaJx6fautatGhRzj///HzgAx/Ipz/96Zo6enrmmWdy\n880355VXXkljY2N6e3uzePHiHHHEETW9Z0n1tdXyvj344IMZHBzMNddcM3rbqaeemk996lM1vW/V\n1jUVe+bykgBQmNPOAFCY+AJAYeILAIWJLwAUJr4AUJj4AkBh4gsAhYkvABT2/wCHFGn97RpH2QAA\nAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f768c757080>"
]
},
"metadata": {
"tags": []
}
}
]
}
]
}