Last active
February 6, 2018 12:38
-
-
Save vinupriyesh/f30fd64b85a6026685330909e16930d4 to your computer and use it in GitHub Desktop.
Basic RNN using pytorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Problem statement\n", | |
"\n", | |
"Same as https://gist.github.com/vinupriyesh/c764d26100e127d0d0f434b1c5b2cd51 but to design using RNN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from torch import nn\n", | |
"from torch.autograd import Variable" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Our model this time will be using RNN, here h0 is attached to the next input node" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class SimpleRNNModule(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(SimpleRNNModule,self).__init__()\n", | |
" self.fc1 = nn.Linear(4,3)\n", | |
" self.fc2 = nn.Linear(3,3)\n", | |
" self.h0 = Variable(torch.zeros(1,3))\n", | |
" \n", | |
" def forward(self,x):\n", | |
" x_effective = torch.cat((x,self.h0),1)\n", | |
" z1 = self.fc1(x_effective)\n", | |
" a1 = torch.tanh(z1)\n", | |
" self.h0 = a1\n", | |
" z2 = self.fc2(a1)\n", | |
" a2 = torch.sigmoid(z2)\n", | |
" return a2 \n", | |
" \n", | |
" def reset(self):\n", | |
" self.h0 = Variable(torch.zeros(1,3))\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Model is ready, next is input and output. Here the output is in a sequence so it will be different than the previous example" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def generate_input():\n", | |
" x = torch.Tensor(1, 1).uniform_(0, 1)\n", | |
" return torch.bernoulli(x) \n", | |
"\n", | |
"\n", | |
"def generate_output(x):\n", | |
" y = torch.zeros(1,3)\n", | |
" sum_x = x.sum()\n", | |
" if sum_x == 0:\n", | |
" y[0,0] = 1\n", | |
" elif sum_x == 1:\n", | |
" y[0,1] = 1\n", | |
" else:\n", | |
" y[0,2] = 1\n", | |
" return y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"\n", | |
" 1\n", | |
"[torch.FloatTensor of size 1x1]" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"generate_input()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"can create the model, the loss and the loop" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss on iteration 0 is 0.8868908286094666\n", | |
"Loss on iteration 100 is 0.6027839779853821\n", | |
"Loss on iteration 200 is 0.06013947352766991\n", | |
"Loss on iteration 300 is 0.14052194356918335\n", | |
"Loss on iteration 400 is 0.0651470199227333\n", | |
"Loss on iteration 500 is 0.02559574507176876\n", | |
"Loss on iteration 600 is 0.021827315911650658\n", | |
"Loss on iteration 700 is 0.004783849697560072\n", | |
"Loss on iteration 800 is 0.011788919568061829\n", | |
"Loss on iteration 900 is 0.004893178585916758\n" | |
] | |
} | |
], | |
"source": [ | |
"nn_model = SimpleRNNModule()\n", | |
"criterion = nn.BCELoss() #Binary Cross Entropy Loss\n", | |
"optimizer = torch.optim.Adam(nn_model.parameters(),lr = 0.1)\n", | |
"optimizer.zero_grad()\n", | |
"\n", | |
"x_prev = torch.zeros(1,1)\n", | |
"\n", | |
"for i in range(1000):\n", | |
" x = generate_input()\n", | |
" y = generate_output(torch.cat((x_prev,x),1))\n", | |
" x = Variable(x)\n", | |
" y = Variable(y)\n", | |
" y_hat = nn_model(x)\n", | |
" loss = criterion(y_hat,y)\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward(retain_graph=True) \n", | |
" optimizer.step()\n", | |
" x_prev = x.data\n", | |
" if i % 100 == 0:\n", | |
" print(\"Loss on iteration {} is {}\".format(i,loss.data[0])) " | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Check how did this perform" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"All done\n" | |
] | |
} | |
], | |
"source": [ | |
"def check_model():\n", | |
" nn_model.reset()\n", | |
" \n", | |
" x = Variable(torch.FloatTensor([0])).unsqueeze(1)\n", | |
" y = nn_model(x).round().int().squeeze() \n", | |
" assert y.data[0] == 1 and y.data.sum() == 1\n", | |
" \n", | |
" x = Variable(torch.FloatTensor([1])).unsqueeze(1)\n", | |
" y = nn_model(x).round().int().squeeze()\n", | |
" assert y.data[1] == 1 and y.data.sum() == 1\n", | |
" \n", | |
" x = Variable(torch.FloatTensor([0])).unsqueeze(1)\n", | |
" y = nn_model(x).round().int().squeeze()\n", | |
" assert y.data[1] == 1 and y.data.sum() == 1\n", | |
" \n", | |
" x = Variable(torch.FloatTensor([1])).unsqueeze(1)\n", | |
" y = nn_model(x).round().int().squeeze()\n", | |
" assert y.data[1] == 1 and y.data.sum() == 1\n", | |
" \n", | |
" x = Variable(torch.FloatTensor([1])).unsqueeze(1)\n", | |
" y = nn_model(x).round().int().squeeze()\n", | |
" assert y.data[2] == 1 and y.data.sum() == 1\n", | |
"\n", | |
" \n", | |
"\n", | |
"check_model()\n", | |
"print(\"All done\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment