Last active
March 21, 2020 05:08
-
-
Save mirsahib/adcf81ba6986a2ffc1f449e71fe417fe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torchvision\n", | |
"from torchvision import transforms,datasets" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from torch import nn, optim\n", | |
"import torch.nn.functional as F\n", | |
"import pdb" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 75, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data = pd.read_csv('Joint_Dataset.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>left_shoulder_X</th>\n", | |
" <th>left_shoulder_Y</th>\n", | |
" <th>left_shoulder_Z</th>\n", | |
" <th>left_elbow_X</th>\n", | |
" <th>left_elbow_Y</th>\n", | |
" <th>left_elbow_Z</th>\n", | |
" <th>left_wrist_X</th>\n", | |
" <th>left_wrist_Y</th>\n", | |
" <th>left_wrist_Z</th>\n", | |
" <th>left_hand_X</th>\n", | |
" <th>...</th>\n", | |
" <th>right_knee_X</th>\n", | |
" <th>right_knee_Y</th>\n", | |
" <th>right_knee_Z</th>\n", | |
" <th>right_ankle_X</th>\n", | |
" <th>right_ankle_Y</th>\n", | |
" <th>right_ankle_Z</th>\n", | |
" <th>right_foot_X</th>\n", | |
" <th>right_foot_Y</th>\n", | |
" <th>right_foot_Z</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>120.0</td>\n", | |
" <td>80.0</td>\n", | |
" <td>-269.0</td>\n", | |
" <td>114.0</td>\n", | |
" <td>108.0</td>\n", | |
" <td>-267.0</td>\n", | |
" <td>116.0</td>\n", | |
" <td>139.0</td>\n", | |
" <td>-255.0</td>\n", | |
" <td>115.0</td>\n", | |
" <td>...</td>\n", | |
" <td>153.0</td>\n", | |
" <td>182.0</td>\n", | |
" <td>-276.0</td>\n", | |
" <td>154.0</td>\n", | |
" <td>225.0</td>\n", | |
" <td>-285.0</td>\n", | |
" <td>153.0</td>\n", | |
" <td>234.0</td>\n", | |
" <td>-243.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>121.0</td>\n", | |
" <td>81.0</td>\n", | |
" <td>-268.0</td>\n", | |
" <td>114.0</td>\n", | |
" <td>109.0</td>\n", | |
" <td>-266.0</td>\n", | |
" <td>115.0</td>\n", | |
" <td>140.0</td>\n", | |
" <td>-255.0</td>\n", | |
" <td>117.0</td>\n", | |
" <td>...</td>\n", | |
" <td>153.0</td>\n", | |
" <td>182.0</td>\n", | |
" <td>-276.0</td>\n", | |
" <td>154.0</td>\n", | |
" <td>226.0</td>\n", | |
" <td>-284.0</td>\n", | |
" <td>154.0</td>\n", | |
" <td>234.0</td>\n", | |
" <td>-242.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>120.0</td>\n", | |
" <td>81.0</td>\n", | |
" <td>-269.0</td>\n", | |
" <td>114.0</td>\n", | |
" <td>109.0</td>\n", | |
" <td>-266.0</td>\n", | |
" <td>115.0</td>\n", | |
" <td>140.0</td>\n", | |
" <td>-257.0</td>\n", | |
" <td>117.0</td>\n", | |
" <td>...</td>\n", | |
" <td>153.0</td>\n", | |
" <td>182.0</td>\n", | |
" <td>-276.0</td>\n", | |
" <td>155.0</td>\n", | |
" <td>226.0</td>\n", | |
" <td>-283.0</td>\n", | |
" <td>155.0</td>\n", | |
" <td>234.0</td>\n", | |
" <td>-241.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>120.0</td>\n", | |
" <td>81.0</td>\n", | |
" <td>-269.0</td>\n", | |
" <td>113.0</td>\n", | |
" <td>109.0</td>\n", | |
" <td>-266.0</td>\n", | |
" <td>115.0</td>\n", | |
" <td>140.0</td>\n", | |
" <td>-255.0</td>\n", | |
" <td>116.0</td>\n", | |
" <td>...</td>\n", | |
" <td>153.0</td>\n", | |
" <td>182.0</td>\n", | |
" <td>-275.0</td>\n", | |
" <td>154.0</td>\n", | |
" <td>225.0</td>\n", | |
" <td>-286.0</td>\n", | |
" <td>154.0</td>\n", | |
" <td>234.0</td>\n", | |
" <td>-244.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>120.0</td>\n", | |
" <td>81.0</td>\n", | |
" <td>-269.0</td>\n", | |
" <td>114.0</td>\n", | |
" <td>109.0</td>\n", | |
" <td>-267.0</td>\n", | |
" <td>115.0</td>\n", | |
" <td>140.0</td>\n", | |
" <td>-255.0</td>\n", | |
" <td>117.0</td>\n", | |
" <td>...</td>\n", | |
" <td>153.0</td>\n", | |
" <td>182.0</td>\n", | |
" <td>-275.0</td>\n", | |
" <td>154.0</td>\n", | |
" <td>225.0</td>\n", | |
" <td>-285.0</td>\n", | |
" <td>154.0</td>\n", | |
" <td>234.0</td>\n", | |
" <td>-243.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>5 rows × 61 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" left_shoulder_X left_shoulder_Y left_shoulder_Z left_elbow_X \\\n", | |
"0 120.0 80.0 -269.0 114.0 \n", | |
"1 121.0 81.0 -268.0 114.0 \n", | |
"2 120.0 81.0 -269.0 114.0 \n", | |
"3 120.0 81.0 -269.0 113.0 \n", | |
"4 120.0 81.0 -269.0 114.0 \n", | |
"\n", | |
" left_elbow_Y left_elbow_Z left_wrist_X left_wrist_Y left_wrist_Z \\\n", | |
"0 108.0 -267.0 116.0 139.0 -255.0 \n", | |
"1 109.0 -266.0 115.0 140.0 -255.0 \n", | |
"2 109.0 -266.0 115.0 140.0 -257.0 \n", | |
"3 109.0 -266.0 115.0 140.0 -255.0 \n", | |
"4 109.0 -267.0 115.0 140.0 -255.0 \n", | |
"\n", | |
" left_hand_X ... right_knee_X right_knee_Y right_knee_Z right_ankle_X \\\n", | |
"0 115.0 ... 153.0 182.0 -276.0 154.0 \n", | |
"1 117.0 ... 153.0 182.0 -276.0 154.0 \n", | |
"2 117.0 ... 153.0 182.0 -276.0 155.0 \n", | |
"3 116.0 ... 153.0 182.0 -275.0 154.0 \n", | |
"4 117.0 ... 153.0 182.0 -275.0 154.0 \n", | |
"\n", | |
" right_ankle_Y right_ankle_Z right_foot_X right_foot_Y right_foot_Z \\\n", | |
"0 225.0 -285.0 153.0 234.0 -243.0 \n", | |
"1 226.0 -284.0 154.0 234.0 -242.0 \n", | |
"2 226.0 -283.0 155.0 234.0 -241.0 \n", | |
"3 225.0 -286.0 154.0 234.0 -244.0 \n", | |
"4 225.0 -285.0 154.0 234.0 -243.0 \n", | |
"\n", | |
" label \n", | |
"0 1 \n", | |
"1 1 \n", | |
"2 1 \n", | |
"3 1 \n", | |
"4 1 \n", | |
"\n", | |
"[5 rows x 61 columns]" | |
] | |
}, | |
"execution_count": 76, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 77, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((22797, 60), (22797,))" | |
] | |
}, | |
"execution_count": 77, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X = data.drop('label',axis=1)\n", | |
"Y = data['label']\n", | |
"X.shape,Y.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_train = torch.from_numpy(X_train.to_numpy()).float()\n", | |
"#y_train = torch.squeeze(torch.from_numpy(y_train.to_numpy()).float())\n", | |
"y_th_train = torch.from_numpy(y_train.to_numpy()).type(torch.LongTensor)\n", | |
"y_th_train = y_th_train-1\n", | |
"\n", | |
"X_test = torch.from_numpy(X_test.to_numpy()).float()\n", | |
"#y_test = torch.squeeze(torch.from_numpy(y_test.to_numpy()).float())\n", | |
"y_th_test = torch.from_numpy(y_test.to_numpy()).type(torch.LongTensor)\n", | |
"y_th_test=y_th_test-1\n", | |
"\n", | |
"#rint(X_train.shape, y_train.shape)\n", | |
"#rint(X_test.shape, y_test.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 81, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Net(nn.Module):\n", | |
" def __init__(self, n_features):\n", | |
" super(Net, self).__init__()\n", | |
" self.fc1 = nn.Linear(n_features, 40)\n", | |
" self.fc2 = nn.Linear(40, 40)\n", | |
" self.fc3 = nn.Linear(40, 40)\n", | |
" self.fc4 = nn.Linear(40,40)\n", | |
" self.fc5 = nn.Linear(40,20)\n", | |
" def forward(self, x):\n", | |
" x = F.relu(self.fc1(x))\n", | |
" x = F.relu(self.fc2(x))\n", | |
" x = F.relu(self.fc3(x))\n", | |
" x = F.relu(self.fc4(x))\n", | |
" x = self.fc5(x)\n", | |
" return F.log_softmax(x,dim=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"net = Net(X_train.shape[1])\n", | |
"#ann_viz(net, view=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"criterion = nn.CrossEntropyLoss()\n", | |
"optimizer = optim.Adam(net.parameters(),lr=0.001)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def calculate_accuracy(y_true, y_pred):\n", | |
" predicted = y_pred.ge(.5).view(-1)\n", | |
" return (y_true == predicted).sum().float() / len(y_true)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def round_tensor(t, decimal_places=3):\n", | |
" return round(t.item(), decimal_places)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"epoch 0 Train set - loss: 19650.602, accuracy: 5.2 Test set - loss: 17926.424, accuracy: 4.8 \n", | |
"epoch 100 Train set - loss: 140.026, accuracy: 6.1 Test set - loss: 179.069, accuracy: 5.2 \n", | |
"epoch 200 Train set - loss: 165.377, accuracy: 7.3 Test set - loss: 147.435, accuracy: 6.4 \n", | |
"epoch 300 Train set - loss: 148.573, accuracy: 7.3 Test set - loss: 110.559, accuracy: 6.5 \n", | |
"epoch 400 Train set - loss: 89.973, accuracy: 7.3999999999999995 Test set - loss: 77.384, accuracy: 6.5 \n", | |
"epoch 500 Train set - loss: 52.287, accuracy: 7.6 Test set - loss: 62.641, accuracy: 6.4 \n", | |
"epoch 600 Train set - loss: 68.02, accuracy: 7.3999999999999995 Test set - loss: 81.178, accuracy: 6.5 \n", | |
"epoch 700 Train set - loss: 50.873, accuracy: 7.3999999999999995 Test set - loss: 39.847, accuracy: 6.5 \n", | |
"epoch 800 Train set - loss: 14.206, accuracy: 7.3999999999999995 Test set - loss: 12.841, accuracy: 6.4 \n", | |
"epoch 900 Train set - loss: 34.583, accuracy: 7.5 Test set - loss: 26.926, accuracy: 6.5 \n" | |
] | |
} | |
], | |
"source": [ | |
"\n", | |
"for epoch in range(1000):\n", | |
" y_pred = net(X_train)\n", | |
" #pdb.set_trace()\n", | |
" #optimizer.zero_grad()\n", | |
" train_loss = criterion(y_pred,y_th_train)\n", | |
" if epoch % 100 == 0:\n", | |
" #pdb.set_trace()\n", | |
" y_pred = y_pred.argmax(dim=1)\n", | |
" train_acc = calculate_accuracy(y_th_train, y_pred)\n", | |
" #pdb.set_trace()\n", | |
" y_test_pred = net(X_test)\n", | |
" test_loss = criterion(y_test_pred, y_th_test)\n", | |
" y_test_pred = y_test_pred.argmax(dim=1)\n", | |
" test_acc = calculate_accuracy(y_th_test, y_test_pred)\n", | |
" print(f'''epoch {epoch} Train set - loss: {round_tensor(train_loss)}, accuracy: {round_tensor(train_acc)*100} Test set - loss: {round_tensor(test_loss)}, accuracy: {round_tensor(test_acc)*100} ''')\n", | |
" \n", | |
" optimizer.zero_grad()\n", | |
" train_loss.backward()\n", | |
" optimizer.step()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment