Skip to content

Instantly share code, notes, and snippets.

@vinupriyesh
Last active February 6, 2018 12:38
Show Gist options
  • Save vinupriyesh/f30fd64b85a6026685330909e16930d4 to your computer and use it in GitHub Desktop.
Save vinupriyesh/f30fd64b85a6026685330909e16930d4 to your computer and use it in GitHub Desktop.
Basic RNN using pytorch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Problem statement\n",
"\n",
"Same problem as in https://gist.github.com/vinupriyesh/c764d26100e127d0d0f434b1c5b2cd51, but this time the model is designed as an RNN."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This time our model is an RNN: the hidden state `h0` produced by one step is concatenated with the input of the next step."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class SimpleRNNModule(nn.Module):\n",
"    \"\"\"Minimal hand-rolled recurrent cell with a sigmoid read-out.\n",
"\n",
"    On every call the input is concatenated with the hidden state kept from\n",
"    the previous call and passed through a tanh layer; that activation becomes\n",
"    the new hidden state and is also fed to a sigmoid output layer.\n",
"\n",
"    The hidden state is stored with a batch dimension of 1, so inputs must\n",
"    have shape (1, input_size).\n",
"\n",
"    Args:\n",
"        input_size: number of features per time step (default 1).\n",
"        hidden_size: width of the hidden state (default 3).\n",
"        output_size: number of sigmoid outputs (default 3).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size=1, hidden_size=3, output_size=3):\n",
"        super(SimpleRNNModule, self).__init__()\n",
"        self.hidden_size = hidden_size\n",
"        # fc1 sees the current input concatenated with the previous hidden state.\n",
"        self.fc1 = nn.Linear(input_size + hidden_size, hidden_size)\n",
"        self.fc2 = nn.Linear(hidden_size, output_size)\n",
"        self.reset()\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Advance one time step and return the sigmoid outputs for input x.\"\"\"\n",
"        x_effective = torch.cat((x, self.h0), 1)\n",
"        a1 = torch.tanh(self.fc1(x_effective))\n",
"        # Keep only the values of the hidden state. Holding on to its graph\n",
"        # would make every later backward() walk through all earlier steps,\n",
"        # whose weights the optimizer has meanwhile changed in place. The\n",
"        # trade-off is that gradients are truncated to a single time step.\n",
"        self.h0 = a1.detach()\n",
"        return torch.sigmoid(self.fc2(a1))\n",
"\n",
"    def reset(self):\n",
"        \"\"\"Set the hidden state back to zeros, e.g. before a new sequence.\"\"\"\n",
"        self.h0 = Variable(torch.zeros(1, self.hidden_size))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model is ready; next come the input and the output. Here the output depends on a sequence of inputs, so it is generated differently from the previous example."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def generate_input():\n",
"    \"\"\"Draw one random bit and return it as a 1x1 float tensor.\n",
"\n",
"    A probability is first sampled uniformly between 0 and 1 and then used\n",
"    for a single Bernoulli draw.\n",
"    \"\"\"\n",
"    probability = torch.Tensor(1, 1).uniform_(0, 1)\n",
"    bit = torch.bernoulli(probability)\n",
"    return bit\n",
"\n",
"\n",
"def generate_output(x):\n",
"    \"\"\"Return a 1x3 one-hot target chosen by the sum of x.\n",
"\n",
"    A sum of 0 sets position 0, a sum of 1 sets position 1 and any other\n",
"    sum sets position 2.\n",
"    \"\"\"\n",
"    total = float(x.sum())\n",
"    # Sums other than exactly 0 or 1 fall through to the last slot.\n",
"    hot_index = {0: 0, 1: 1}.get(total, 2)\n",
"    target = torch.zeros(1, 3)\n",
"    target[0, hot_index] = 1\n",
"    return target"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\n",
" 1\n",
"[torch.FloatTensor of size 1x1]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_input()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create the model, the loss, and the training loop."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss on iteration 0 is 0.8868908286094666\n",
"Loss on iteration 100 is 0.6027839779853821\n",
"Loss on iteration 200 is 0.06013947352766991\n",
"Loss on iteration 300 is 0.14052194356918335\n",
"Loss on iteration 400 is 0.0651470199227333\n",
"Loss on iteration 500 is 0.02559574507176876\n",
"Loss on iteration 600 is 0.021827315911650658\n",
"Loss on iteration 700 is 0.004783849697560072\n",
"Loss on iteration 800 is 0.011788919568061829\n",
"Loss on iteration 900 is 0.004893178585916758\n"
]
}
],
"source": [
"# Online training: one optimizer step per generated time step.\n",
"nn_model = SimpleRNNModule()\n",
"criterion = nn.BCELoss()  # Binary Cross Entropy Loss\n",
"optimizer = torch.optim.Adam(nn_model.parameters(), lr=0.1)\n",
"\n",
"# The first target is computed as if a 0 had been seen before the first input.\n",
"x_prev = torch.zeros(1, 1)\n",
"\n",
"for i in range(1000):\n",
"    x = generate_input()\n",
"    # The target encodes the sum of the previous and the current input.\n",
"    y = generate_output(torch.cat((x_prev, x), 1))\n",
"    y_hat = nn_model(x)\n",
"    loss = criterion(y_hat, y)\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    # Cut the stored hidden state off from this step's graph. Otherwise the\n",
"    # next backward() would have to revisit this step (hence the old\n",
"    # retain_graph=True) using weights that optimizer.step() just changed in\n",
"    # place, which current PyTorch rejects with a RuntimeError.\n",
"    nn_model.h0 = nn_model.h0.detach()\n",
"    x_prev = x\n",
"    if i % 100 == 0:\n",
"        # loss is a 0-dim tensor; .item() extracts the Python float.\n",
"        print(\"Loss on iteration {} is {}\".format(i, loss.item()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check how well the trained model performs."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All done\n"
]
}
],
"source": [
"def check_model():\n",
"    \"\"\"Replay a fixed input sequence and assert the model's rounded outputs.\n",
"\n",
"    Uses the notebook-level `nn_model`. Its hidden state is reset first, then\n",
"    the bits 0, 1, 0, 1, 1 are fed one at a time. After each step exactly one\n",
"    output must round to 1, and it must sit at the expected index.\n",
"    \"\"\"\n",
"    nn_model.reset()\n",
"\n",
"    # (input bit, index of the output that must be 1 after that step)\n",
"    steps = [(0, 0), (1, 1), (0, 1), (1, 1), (1, 2)]\n",
"    for bit, hot_index in steps:\n",
"        x = Variable(torch.FloatTensor([bit])).unsqueeze(1)\n",
"        y = nn_model(x).round().int().squeeze()\n",
"        assert y.data[hot_index] == 1 and y.data.sum() == 1\n",
"\n",
"\n",
"check_model()\n",
"print(\"All done\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment