Skip to content

Instantly share code, notes, and snippets.

@mirsahib
Last active March 21, 2020 05:08
Show Gist options
  • Save mirsahib/adcf81ba6986a2ffc1f449e71fe417fe to your computer and use it in GitHub Desktop.
Save mirsahib/adcf81ba6986a2ffc1f449e71fe417fe to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"from torchvision import transforms,datasets"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"from torch import nn, optim\n",
"import torch.nn.functional as F\n",
"import pdb"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv('Joint_Dataset.csv')"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>left_shoulder_X</th>\n",
" <th>left_shoulder_Y</th>\n",
" <th>left_shoulder_Z</th>\n",
" <th>left_elbow_X</th>\n",
" <th>left_elbow_Y</th>\n",
" <th>left_elbow_Z</th>\n",
" <th>left_wrist_X</th>\n",
" <th>left_wrist_Y</th>\n",
" <th>left_wrist_Z</th>\n",
" <th>left_hand_X</th>\n",
" <th>...</th>\n",
" <th>right_knee_X</th>\n",
" <th>right_knee_Y</th>\n",
" <th>right_knee_Z</th>\n",
" <th>right_ankle_X</th>\n",
" <th>right_ankle_Y</th>\n",
" <th>right_ankle_Z</th>\n",
" <th>right_foot_X</th>\n",
" <th>right_foot_Y</th>\n",
" <th>right_foot_Z</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>120.0</td>\n",
" <td>80.0</td>\n",
" <td>-269.0</td>\n",
" <td>114.0</td>\n",
" <td>108.0</td>\n",
" <td>-267.0</td>\n",
" <td>116.0</td>\n",
" <td>139.0</td>\n",
" <td>-255.0</td>\n",
" <td>115.0</td>\n",
" <td>...</td>\n",
" <td>153.0</td>\n",
" <td>182.0</td>\n",
" <td>-276.0</td>\n",
" <td>154.0</td>\n",
" <td>225.0</td>\n",
" <td>-285.0</td>\n",
" <td>153.0</td>\n",
" <td>234.0</td>\n",
" <td>-243.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>121.0</td>\n",
" <td>81.0</td>\n",
" <td>-268.0</td>\n",
" <td>114.0</td>\n",
" <td>109.0</td>\n",
" <td>-266.0</td>\n",
" <td>115.0</td>\n",
" <td>140.0</td>\n",
" <td>-255.0</td>\n",
" <td>117.0</td>\n",
" <td>...</td>\n",
" <td>153.0</td>\n",
" <td>182.0</td>\n",
" <td>-276.0</td>\n",
" <td>154.0</td>\n",
" <td>226.0</td>\n",
" <td>-284.0</td>\n",
" <td>154.0</td>\n",
" <td>234.0</td>\n",
" <td>-242.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>120.0</td>\n",
" <td>81.0</td>\n",
" <td>-269.0</td>\n",
" <td>114.0</td>\n",
" <td>109.0</td>\n",
" <td>-266.0</td>\n",
" <td>115.0</td>\n",
" <td>140.0</td>\n",
" <td>-257.0</td>\n",
" <td>117.0</td>\n",
" <td>...</td>\n",
" <td>153.0</td>\n",
" <td>182.0</td>\n",
" <td>-276.0</td>\n",
" <td>155.0</td>\n",
" <td>226.0</td>\n",
" <td>-283.0</td>\n",
" <td>155.0</td>\n",
" <td>234.0</td>\n",
" <td>-241.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>120.0</td>\n",
" <td>81.0</td>\n",
" <td>-269.0</td>\n",
" <td>113.0</td>\n",
" <td>109.0</td>\n",
" <td>-266.0</td>\n",
" <td>115.0</td>\n",
" <td>140.0</td>\n",
" <td>-255.0</td>\n",
" <td>116.0</td>\n",
" <td>...</td>\n",
" <td>153.0</td>\n",
" <td>182.0</td>\n",
" <td>-275.0</td>\n",
" <td>154.0</td>\n",
" <td>225.0</td>\n",
" <td>-286.0</td>\n",
" <td>154.0</td>\n",
" <td>234.0</td>\n",
" <td>-244.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>120.0</td>\n",
" <td>81.0</td>\n",
" <td>-269.0</td>\n",
" <td>114.0</td>\n",
" <td>109.0</td>\n",
" <td>-267.0</td>\n",
" <td>115.0</td>\n",
" <td>140.0</td>\n",
" <td>-255.0</td>\n",
" <td>117.0</td>\n",
" <td>...</td>\n",
" <td>153.0</td>\n",
" <td>182.0</td>\n",
" <td>-275.0</td>\n",
" <td>154.0</td>\n",
" <td>225.0</td>\n",
" <td>-285.0</td>\n",
" <td>154.0</td>\n",
" <td>234.0</td>\n",
" <td>-243.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 61 columns</p>\n",
"</div>"
],
"text/plain": [
" left_shoulder_X left_shoulder_Y left_shoulder_Z left_elbow_X \\\n",
"0 120.0 80.0 -269.0 114.0 \n",
"1 121.0 81.0 -268.0 114.0 \n",
"2 120.0 81.0 -269.0 114.0 \n",
"3 120.0 81.0 -269.0 113.0 \n",
"4 120.0 81.0 -269.0 114.0 \n",
"\n",
" left_elbow_Y left_elbow_Z left_wrist_X left_wrist_Y left_wrist_Z \\\n",
"0 108.0 -267.0 116.0 139.0 -255.0 \n",
"1 109.0 -266.0 115.0 140.0 -255.0 \n",
"2 109.0 -266.0 115.0 140.0 -257.0 \n",
"3 109.0 -266.0 115.0 140.0 -255.0 \n",
"4 109.0 -267.0 115.0 140.0 -255.0 \n",
"\n",
" left_hand_X ... right_knee_X right_knee_Y right_knee_Z right_ankle_X \\\n",
"0 115.0 ... 153.0 182.0 -276.0 154.0 \n",
"1 117.0 ... 153.0 182.0 -276.0 154.0 \n",
"2 117.0 ... 153.0 182.0 -276.0 155.0 \n",
"3 116.0 ... 153.0 182.0 -275.0 154.0 \n",
"4 117.0 ... 153.0 182.0 -275.0 154.0 \n",
"\n",
" right_ankle_Y right_ankle_Z right_foot_X right_foot_Y right_foot_Z \\\n",
"0 225.0 -285.0 153.0 234.0 -243.0 \n",
"1 226.0 -284.0 154.0 234.0 -242.0 \n",
"2 226.0 -283.0 155.0 234.0 -241.0 \n",
"3 225.0 -286.0 154.0 234.0 -244.0 \n",
"4 225.0 -285.0 154.0 234.0 -243.0 \n",
"\n",
" label \n",
"0 1 \n",
"1 1 \n",
"2 1 \n",
"3 1 \n",
"4 1 \n",
"\n",
"[5 rows x 61 columns]"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((22797, 60), (22797,))"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = data.drop('label',axis=1)\n",
"Y = data['label']\n",
"X.shape,Y.shape"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"X_train = torch.from_numpy(X_train.to_numpy()).float()\n",
"#y_train = torch.squeeze(torch.from_numpy(y_train.to_numpy()).float())\n",
"y_th_train = torch.from_numpy(y_train.to_numpy()).type(torch.LongTensor)\n",
"y_th_train = y_th_train-1\n",
"\n",
"X_test = torch.from_numpy(X_test.to_numpy()).float()\n",
"#y_test = torch.squeeze(torch.from_numpy(y_test.to_numpy()).float())\n",
"y_th_test = torch.from_numpy(y_test.to_numpy()).type(torch.LongTensor)\n",
"y_th_test=y_th_test-1\n",
"\n",
"#rint(X_train.shape, y_train.shape)\n",
"#rint(X_test.shape, y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
" def __init__(self, n_features):\n",
" super(Net, self).__init__()\n",
" self.fc1 = nn.Linear(n_features, 40)\n",
" self.fc2 = nn.Linear(40, 40)\n",
" self.fc3 = nn.Linear(40, 40)\n",
" self.fc4 = nn.Linear(40,40)\n",
" self.fc5 = nn.Linear(40,20)\n",
" def forward(self, x):\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = F.relu(self.fc3(x))\n",
" x = F.relu(self.fc4(x))\n",
" x = self.fc5(x)\n",
" return F.log_softmax(x,dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"net = Net(X_train.shape[1])\n",
"#ann_viz(net, view=False)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(net.parameters(),lr=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"def calculate_accuracy(y_true, y_pred):\n",
" predicted = y_pred.ge(.5).view(-1)\n",
" return (y_true == predicted).sum().float() / len(y_true)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"def round_tensor(t, decimal_places=3):\n",
" return round(t.item(), decimal_places)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0 Train set - loss: 19650.602, accuracy: 5.2 Test set - loss: 17926.424, accuracy: 4.8 \n",
"epoch 100 Train set - loss: 140.026, accuracy: 6.1 Test set - loss: 179.069, accuracy: 5.2 \n",
"epoch 200 Train set - loss: 165.377, accuracy: 7.3 Test set - loss: 147.435, accuracy: 6.4 \n",
"epoch 300 Train set - loss: 148.573, accuracy: 7.3 Test set - loss: 110.559, accuracy: 6.5 \n",
"epoch 400 Train set - loss: 89.973, accuracy: 7.3999999999999995 Test set - loss: 77.384, accuracy: 6.5 \n",
"epoch 500 Train set - loss: 52.287, accuracy: 7.6 Test set - loss: 62.641, accuracy: 6.4 \n",
"epoch 600 Train set - loss: 68.02, accuracy: 7.3999999999999995 Test set - loss: 81.178, accuracy: 6.5 \n",
"epoch 700 Train set - loss: 50.873, accuracy: 7.3999999999999995 Test set - loss: 39.847, accuracy: 6.5 \n",
"epoch 800 Train set - loss: 14.206, accuracy: 7.3999999999999995 Test set - loss: 12.841, accuracy: 6.4 \n",
"epoch 900 Train set - loss: 34.583, accuracy: 7.5 Test set - loss: 26.926, accuracy: 6.5 \n"
]
}
],
"source": [
"\n",
"for epoch in range(1000):\n",
" y_pred = net(X_train)\n",
" #pdb.set_trace()\n",
" #optimizer.zero_grad()\n",
" train_loss = criterion(y_pred,y_th_train)\n",
" if epoch % 100 == 0:\n",
" #pdb.set_trace()\n",
" y_pred = y_pred.argmax(dim=1)\n",
" train_acc = calculate_accuracy(y_th_train, y_pred)\n",
" #pdb.set_trace()\n",
" y_test_pred = net(X_test)\n",
" test_loss = criterion(y_test_pred, y_th_test)\n",
" y_test_pred = y_test_pred.argmax(dim=1)\n",
" test_acc = calculate_accuracy(y_th_test, y_test_pred)\n",
" print(f'''epoch {epoch} Train set - loss: {round_tensor(train_loss)}, accuracy: {round_tensor(train_acc)*100} Test set - loss: {round_tensor(test_loss)}, accuracy: {round_tensor(test_acc)*100} ''')\n",
" \n",
" optimizer.zero_grad()\n",
" train_loss.backward()\n",
" optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment