Skip to content

Instantly share code, notes, and snippets.

@firmai
Created April 11, 2022 19:48
Show Gist options
  • Save firmai/70b394526ca7499bb27633565580aaa2 to your computer and use it in GitHub Desktop.
Save firmai/70b394526ca7499bb27633565580aaa2 to your computer and use it in GitHub Desktop.
DeepLOB.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/firmai/70b394526ca7499bb27633565580aaa2/deeplob.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C6iAiQKNBpjN"
},
"source": [
"# DeepLOB: Deep Convolutional Neural Networks for Limit Order Books\n",
"\n",
"### Authors: Zihao Zhang, Stefan Zohren and Stephen Roberts\n",
"Oxford-Man Institute of Quantitative Finance, Department of Engineering Science, University of Oxford\n",
"\n",
"Adapted by: Derek Snow (Oxford-Man Associate)\n",
"\n",
"This Jupyter notebook demonstrates our recent paper [2], published in IEEE Transactions on Signal Processing. We use the FI-2010 dataset [1] and show how the model architecture is constructed. "
],
"id": "C6iAiQKNBpjN"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZgNeWIkiBpjQ",
"outputId": "6f5464bf-5f5b-4e43-b04c-0627f6ab4656",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2022-04-08 21:25:46-- https://raw.githubusercontent.com/zcakhaa/DeepLOB-Deep-Convolutional-Neural-Networks-for-Limit-Order-Books/master/data/data.zip\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 56278154 (54M) [application/zip]\n",
"Saving to: ‘data.zip’\n",
"\n",
"data.zip 100%[===================>] 53.67M 182MB/s in 0.3s \n",
"\n",
"2022-04-08 21:25:47 (182 MB/s) - ‘data.zip’ saved [56278154/56278154]\n",
"\n",
"Archive: data.zip\n",
" inflating: Test_Dst_NoAuction_DecPre_CF_7.txt \n",
" inflating: Test_Dst_NoAuction_DecPre_CF_9.txt \n",
" inflating: Test_Dst_NoAuction_DecPre_CF_8.txt \n",
" inflating: Train_Dst_NoAuction_DecPre_CF_7.txt \n",
"data downloaded.\n"
]
}
],
"source": [
"# Warning: this notebook needs a high-RAM runtime (e.g. a paid Google Colab Pro subscription):\n",
"# the windowed float64 datasets built below need well over 10 GB of memory.\n",
"import os\n",
"\n",
"DATA_FILES = [\n",
"    'Train_Dst_NoAuction_DecPre_CF_7.txt',\n",
"    'Test_Dst_NoAuction_DecPre_CF_7.txt',\n",
"    'Test_Dst_NoAuction_DecPre_CF_8.txt',\n",
"    'Test_Dst_NoAuction_DecPre_CF_9.txt',\n",
"]\n",
"\n",
"# Check for the extracted files, not just the archive: an interrupted run can\n",
"# leave data.zip on disk without its contents, and np.loadtxt would then fail.\n",
"if not all(os.path.isfile(name) for name in DATA_FILES):\n",
"    if not os.path.isfile('data.zip'):\n",
"        !wget https://raw.githubusercontent.com/zcakhaa/DeepLOB-Deep-Convolutional-Neural-Networks-for-Limit-Order-Books/master/data/data.zip\n",
"    !unzip -n data.zip\n",
"    print('data downloaded.')\n",
"else:\n",
"    print('data already existed.')"
],
"id": "ZgNeWIkiBpjQ"
},
{
"cell_type": "code",
"source": [
"%pip install -q torchinfo==1.6.5"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BsaASxFrCXzY",
"outputId": "e9fca188-a4d7-4797-ab96-bbd26a8418c9"
},
"id": "BsaASxFrCXzY",
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting torchinfo\n",
" Downloading torchinfo-1.6.5-py3-none-any.whl (21 kB)\n",
"Installing collected packages: torchinfo\n",
"Successfully installed torchinfo-1.6.5\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1pfU-6EjBpjS"
},
"outputs": [],
"source": [
"# load packages\n",
"import pandas as pd\n",
"import pickle\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from datetime import datetime\n",
"from tqdm import tqdm \n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.utils import data\n",
"from torchinfo import summary\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"# from sklearn.model_selection import train_test_split\n",
"# X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.33)\n",
"# N, D = X_train.shape"
],
"id": "1pfU-6EjBpjS"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Hiwoh3LJBpjT",
"outputId": "22aa0fee-e3d4-4a2d-e5d3-cee0837e1a50",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"cuda:0\n"
]
}
],
"source": [
"# Use the first GPU if PyTorch can see one, otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"print(device)"
],
"id": "Hiwoh3LJBpjT"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "avW8cN2MBpjU"
},
"outputs": [],
"source": [
"def prepare_x(data):\n",
"    \"\"\"Select the 40 LOB feature rows of a raw FI-2010 array, returned as (n_samples, 40).\"\"\"\n",
"    df1 = data[:40, :].T\n",
"    return np.array(df1)\n",
"\n",
"def get_label(data):\n",
"    \"\"\"Return the last 5 rows (label horizons) of a raw FI-2010 array as (n_samples, 5).\"\"\"\n",
"    lob = data[-5:, :].T\n",
"    return lob\n",
"\n",
"def data_classification(X, Y, T):\n",
"    \"\"\"Build overlapping windows of T consecutive samples.\n",
"\n",
"    Window i covers rows i .. i+T-1 of X and is paired with row i+T-1 of Y,\n",
"    i.e. the label at the last time step of the window.\n",
"\n",
"    Returns:\n",
"        dataX: float64 array of shape (N - T + 1, T, D)\n",
"        dataY: the rows of Y aligned with dataX (N - T + 1 rows)\n",
"    \"\"\"\n",
"    [N, D] = X.shape\n",
"    df = np.array(X)\n",
"\n",
"    dY = np.array(Y)\n",
"\n",
"    dataY = dY[T - 1:N]\n",
"\n",
"    dataX = np.zeros((N - T + 1, T, D))\n",
"    for i in range(T, N + 1):\n",
"        dataX[i - T] = df[i - T:i, :]\n",
"\n",
"    return dataX, dataY\n",
"\n",
"def torch_data(x, y):\n",
"    \"\"\"Convert windowed arrays to tensors: x gains a channel dim, y becomes one-hot (3 classes).\"\"\"\n",
"    x = torch.from_numpy(x)\n",
"    x = torch.unsqueeze(x, 1)\n",
"    # F.one_hot only accepts int64 tensors; labels derived from the raw data\n",
"    # arrive as float64, so cast before encoding.\n",
"    y = torch.from_numpy(y).long()\n",
"    y = F.one_hot(y, num_classes=3)\n",
"    return x, y"
],
"id": "avW8cN2MBpjU"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LzMyPzCuBpjV"
},
"outputs": [],
"source": [
"class Dataset(data.Dataset):\n",
"    \"\"\"Characterizes a dataset for PyTorch.\n",
"\n",
"    Turns a raw FI-2010 array (one row per feature, one column per time step)\n",
"    into windows of shape (1, T, 40) and the label of horizon column k,\n",
"    shifted down by one so the classes start at 0.\n",
"    \"\"\"\n",
"    def __init__(self, data, k, num_classes, T):\n",
"        \"\"\"Initialization.\n",
"\n",
"        Args:\n",
"            data: raw array whose first 40 rows are LOB features and whose last\n",
"                5 rows are labels for different prediction horizons.\n",
"            k: index (0-4) of the label horizon to use.\n",
"            num_classes: number of label classes; stored for building the model.\n",
"            T: window length in time steps.\n",
"        \"\"\"\n",
"        self.k = k\n",
"        self.num_classes = num_classes\n",
"        self.T = T\n",
"\n",
"        x = prepare_x(data)\n",
"        y = get_label(data)\n",
"        x, y = data_classification(x, y, self.T)\n",
"        y = y[:, self.k] - 1\n",
"        self.length = len(x)\n",
"\n",
"        x = torch.from_numpy(x)\n",
"        # Add a channel dimension so each sample is (1, T, 40) for Conv2d.\n",
"        self.x = torch.unsqueeze(x, 1)\n",
"        self.y = torch.from_numpy(y)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Denotes the total number of samples\"\"\"\n",
"        return self.length\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Generates samples of data\"\"\"\n",
"        return self.x[index], self.y[index]"
],
"id": "LzMyPzCuBpjV"
},
{
"cell_type": "markdown",
"metadata": {
"id": "1dyC08ezBpjW"
},
"source": [
"# Data preparation\n",
"\n",
"The first seven days are training data and the last three days are testing data. The last 20% of the training data is held out as a validation set to monitor overfitting. \n",
"\n",
"Each FI-2010 array is stored with one row per feature and one column per time step. The first 40 rows are 10 levels of ask and bid information from the limit order book, and we only use these 40 features in our network. The last 5 rows are the labels for different prediction horizons. \n",
"\n",
"The FI-2010 dataset is publicly available; more information can be found in the following paper: \n",
"\n",
"Ntakaris A, Magris M, Kanniainen J, Gabbouj M, Iosifidis A. Benchmark dataset for mid‐price forecasting of limit order book data with machine learning methods. Journal of Forecasting. 2018 Dec;37(8):852-66. https://arxiv.org/abs/1705.03233"
],
"id": "1dyC08ezBpjW"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UtnH_1t7BpjW",
"outputId": "f5b102e8-2f02-4e43-be97-9ee2cef9f9d7",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(149, 203800) (149, 50950) (149, 139587)\n"
]
}
],
"source": [
"# Files are read from the working directory populated by the download cell.\n",
"dec_data = np.loadtxt('Train_Dst_NoAuction_DecPre_CF_7.txt')\n",
"\n",
"# Chronological 80/20 split of the training columns into train / validation.\n",
"split_col = int(np.floor(dec_data.shape[1] * 0.8))\n",
"dec_train = dec_data[:, :split_col]\n",
"dec_val = dec_data[:, split_col:]\n",
"\n",
"# The three test files are concatenated along the time (column) axis.\n",
"dec_test1, dec_test2, dec_test3 = (\n",
"    np.loadtxt(name)\n",
"    for name in ('Test_Dst_NoAuction_DecPre_CF_7.txt',\n",
"                 'Test_Dst_NoAuction_DecPre_CF_8.txt',\n",
"                 'Test_Dst_NoAuction_DecPre_CF_9.txt')\n",
")\n",
"dec_test = np.hstack((dec_test1, dec_test2, dec_test3))\n",
"\n",
"print(dec_train.shape, dec_val.shape, dec_test.shape)"
],
"id": "UtnH_1t7BpjW"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lWfCDp6zBpjX",
"outputId": "60eaca4e-59a4-4339-a942-dee1ac293113",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([203701, 1, 100, 40]) torch.Size([203701])\n"
]
}
],
"source": [
"batch_size = 64\n",
"\n",
"# Every split uses the same window length (T=100) and horizon (k=4, the last label row).\n",
"dataset_train, dataset_val, dataset_test = (\n",
"    Dataset(data=split, k=4, num_classes=3, T=100)\n",
"    for split in (dec_train, dec_val, dec_test)\n",
")\n",
"\n",
"# Only the training loader shuffles; validation and test keep their time order.\n",
"train_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)\n",
"val_loader = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=False)\n",
"test_loader = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)\n",
"\n",
"print(dataset_train.x.shape, dataset_train.y.shape)"
],
"id": "lWfCDp6zBpjX"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true,
"id": "t-6zQiyRBpjY",
"outputId": "da9dcac5-3a4d-46ae-bb6b-9df14d26f50c",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[[0.1254, 0.0146, 0.1251, ..., 0.0383, 0.1242, 0.2624],\n",
" [0.1254, 0.0116, 0.1251, ..., 0.0383, 0.1242, 0.2624],\n",
" [0.1253, 0.0030, 0.1251, ..., 0.0488, 0.1242, 0.2624],\n",
" ...,\n",
" [0.1254, 0.0043, 0.1251, ..., 0.0280, 0.1241, 0.0050],\n",
" [0.1254, 0.0043, 0.1251, ..., 0.0280, 0.1241, 0.0050],\n",
" [0.1253, 0.0030, 0.1251, ..., 0.0528, 0.1241, 0.0050]]]],\n",
" dtype=torch.float64)\n",
"tensor([2.], dtype=torch.float64)\n",
"torch.Size([1, 1, 100, 40]) torch.Size([1])\n"
]
}
],
"source": [
"# Peek at one random training sample to check tensor shapes and dtypes.\n",
"tmp_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=1, shuffle=True)\n",
"\n",
"first_batch = next(iter(tmp_loader), None)\n",
"if first_batch is not None:\n",
"    x, y = first_batch\n",
"    print(x)\n",
"    print(y)\n",
"    print(x.shape, y.shape)"
],
"id": "t-6zQiyRBpjY"
},
{
"cell_type": "markdown",
"metadata": {
"id": "EaHztLafBpjY"
},
"source": [
"# Model Architecture\n",
"\n",
"See our paper [2] for a detailed discussion of the model architecture."
],
"id": "EaHztLafBpjY"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6jRdgwGkBpjZ"
},
"outputs": [],
"source": [
"class deeplob(nn.Module):\n",
"    \"\"\"DeepLOB classifier: convolution blocks -> inception module -> LSTM -> linear head.\n",
"\n",
"    Input:  (batch, 1, T, 40) windows of limit-order-book features.\n",
"    Output: (batch, y_len) unnormalised class scores (logits).\n",
"\n",
"    The forward pass returns logits, not probabilities: nn.CrossEntropyLoss\n",
"    applies log-softmax itself, so feeding it softmax outputs would apply\n",
"    softmax twice and distort the loss and its gradients. Use\n",
"    torch.softmax(model(x), dim=1) when probabilities are needed; argmax\n",
"    predictions are the same either way.\n",
"    \"\"\"\n",
"    def __init__(self, y_len):\n",
"        super().__init__()\n",
"        self.y_len = y_len\n",
"\n",
"        # convolution blocks: conv1 and conv2 halve the feature axis with stride-2\n",
"        # (1,2) kernels (40 -> 20 -> 10), conv3 collapses the remaining 10 to 1;\n",
"        # each (4,1) kernel shortens the time axis by 3.\n",
"        self.conv1 = nn.Sequential(\n",
"            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(1,2), stride=(1,2)),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(32),\n",
"            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(4,1)),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(32),\n",
"            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(4,1)),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(32),\n",
"        )\n",
"        self.conv2 = nn.Sequential(\n",
"            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(1,2), stride=(1,2)),\n",
"            nn.Tanh(),\n",
"            nn.BatchNorm2d(32),\n",
"            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(4,1)),\n",
"            nn.Tanh(),\n",
"            nn.BatchNorm2d(32),\n",
"            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(4,1)),\n",
"            nn.Tanh(),\n",
"            nn.BatchNorm2d(32),\n",
"        )\n",
"        self.conv3 = nn.Sequential(\n",
"            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(1,10)),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(32),\n",
"            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(4,1)),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(32),\n",
"            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(4,1)),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(32),\n",
"        )\n",
"\n",
"        # inception modules: three parallel branches over the time axis ('same'\n",
"        # padding keeps its length), concatenated to 3 * 64 = 192 channels.\n",
"        self.inp1 = nn.Sequential(\n",
"            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1,1), padding='same'),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(64),\n",
"            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,1), padding='same'),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(64),\n",
"        )\n",
"        self.inp2 = nn.Sequential(\n",
"            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1,1), padding='same'),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(64),\n",
"            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(5,1), padding='same'),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(64),\n",
"        )\n",
"        self.inp3 = nn.Sequential(\n",
"            nn.MaxPool2d((3, 1), stride=(1, 1), padding=(1, 0)),\n",
"            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1,1), padding='same'),\n",
"            nn.LeakyReLU(negative_slope=0.01),\n",
"            nn.BatchNorm2d(64),\n",
"        )\n",
"\n",
"        # lstm layers\n",
"        self.lstm = nn.LSTM(input_size=192, hidden_size=64, num_layers=1, batch_first=True)\n",
"        self.fc1 = nn.Linear(64, self.y_len)\n",
"\n",
"    def forward(self, x):\n",
"        # h0/c0: (num_layers, batch, hidden_size). Allocate them on the input's\n",
"        # device and dtype instead of relying on a notebook-global `device`.\n",
"        h0 = torch.zeros(1, x.size(0), 64, device=x.device, dtype=x.dtype)\n",
"        c0 = torch.zeros(1, x.size(0), 64, device=x.device, dtype=x.dtype)\n",
"\n",
"        x = self.conv1(x)\n",
"        x = self.conv2(x)\n",
"        x = self.conv3(x)\n",
"\n",
"        x_inp1 = self.inp1(x)\n",
"        x_inp2 = self.inp2(x)\n",
"        x_inp3 = self.inp3(x)\n",
"\n",
"        x = torch.cat((x_inp1, x_inp2, x_inp3), dim=1)\n",
"\n",
"        # (batch, 192, T', 1) -> (batch, T', 192): one 192-dim feature vector per\n",
"        # remaining time step for the batch-first LSTM.\n",
"        x = x.permute(0, 2, 1, 3)\n",
"        x = torch.reshape(x, (-1, x.shape[1], x.shape[2]))\n",
"\n",
"        x, _ = self.lstm(x, (h0, c0))\n",
"        # Classify from the LSTM output at the last time step.\n",
"        x = x[:, -1, :]\n",
"        logits = self.fc1(x)\n",
"\n",
"        return logits"
],
"id": "6jRdgwGkBpjZ"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true,
"id": "xv7Yw2liBpja",
"outputId": "615ed8cf-a06d-4a34-e15f-ad74d792b7d3",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"deeplob(\n",
" (conv1): Sequential(\n",
" (0): Conv2d(1, 32, kernel_size=(1, 2), stride=(1, 2))\n",
" (1): LeakyReLU(negative_slope=0.01)\n",
" (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(32, 32, kernel_size=(4, 1), stride=(1, 1))\n",
" (4): LeakyReLU(negative_slope=0.01)\n",
" (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): Conv2d(32, 32, kernel_size=(4, 1), stride=(1, 1))\n",
" (7): LeakyReLU(negative_slope=0.01)\n",
" (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (conv2): Sequential(\n",
" (0): Conv2d(32, 32, kernel_size=(1, 2), stride=(1, 2))\n",
" (1): Tanh()\n",
" (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(32, 32, kernel_size=(4, 1), stride=(1, 1))\n",
" (4): Tanh()\n",
" (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): Conv2d(32, 32, kernel_size=(4, 1), stride=(1, 1))\n",
" (7): Tanh()\n",
" (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (conv3): Sequential(\n",
" (0): Conv2d(32, 32, kernel_size=(1, 10), stride=(1, 1))\n",
" (1): LeakyReLU(negative_slope=0.01)\n",
" (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(32, 32, kernel_size=(4, 1), stride=(1, 1))\n",
" (4): LeakyReLU(negative_slope=0.01)\n",
" (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (6): Conv2d(32, 32, kernel_size=(4, 1), stride=(1, 1))\n",
" (7): LeakyReLU(negative_slope=0.01)\n",
" (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (inp1): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (1): LeakyReLU(negative_slope=0.01)\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(3, 1), stride=(1, 1), padding=same)\n",
" (4): LeakyReLU(negative_slope=0.01)\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (inp2): Sequential(\n",
" (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (1): LeakyReLU(negative_slope=0.01)\n",
" (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Conv2d(64, 64, kernel_size=(5, 1), stride=(1, 1), padding=same)\n",
" (4): LeakyReLU(negative_slope=0.01)\n",
" (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (inp3): Sequential(\n",
" (0): MaxPool2d(kernel_size=(3, 1), stride=(1, 1), padding=(1, 0), dilation=1, ceil_mode=False)\n",
" (1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), padding=same)\n",
" (2): LeakyReLU(negative_slope=0.01)\n",
" (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (lstm): LSTM(192, 64, batch_first=True)\n",
" (fc1): Linear(in_features=64, out_features=3, bias=True)\n",
")"
]
},
"metadata": {},
"execution_count": 12
}
],
"source": [
"# nn.Module.to returns the module itself, so the bare name below displays its structure.\n",
"model = deeplob(y_len=dataset_train.num_classes).to(device)\n",
"model"
],
"id": "xv7Yw2liBpja"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true,
"id": "xQI93wsNBpjb",
"outputId": "e88204bc-e9ee-4978-acea-6f9560ffac8a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"deeplob -- --\n",
"├─Sequential: 1-1 [1, 32, 94, 20] --\n",
"│ └─Conv2d: 2-1 [1, 32, 100, 20] 96\n",
"│ └─LeakyReLU: 2-2 [1, 32, 100, 20] --\n",
"│ └─BatchNorm2d: 2-3 [1, 32, 100, 20] 64\n",
"│ └─Conv2d: 2-4 [1, 32, 97, 20] 4,128\n",
"│ └─LeakyReLU: 2-5 [1, 32, 97, 20] --\n",
"│ └─BatchNorm2d: 2-6 [1, 32, 97, 20] 64\n",
"│ └─Conv2d: 2-7 [1, 32, 94, 20] 4,128\n",
"│ └─LeakyReLU: 2-8 [1, 32, 94, 20] --\n",
"│ └─BatchNorm2d: 2-9 [1, 32, 94, 20] 64\n",
"├─Sequential: 1-2 [1, 32, 88, 10] --\n",
"│ └─Conv2d: 2-10 [1, 32, 94, 10] 2,080\n",
"│ └─Tanh: 2-11 [1, 32, 94, 10] --\n",
"│ └─BatchNorm2d: 2-12 [1, 32, 94, 10] 64\n",
"│ └─Conv2d: 2-13 [1, 32, 91, 10] 4,128\n",
"│ └─Tanh: 2-14 [1, 32, 91, 10] --\n",
"│ └─BatchNorm2d: 2-15 [1, 32, 91, 10] 64\n",
"│ └─Conv2d: 2-16 [1, 32, 88, 10] 4,128\n",
"│ └─Tanh: 2-17 [1, 32, 88, 10] --\n",
"│ └─BatchNorm2d: 2-18 [1, 32, 88, 10] 64\n",
"├─Sequential: 1-3 [1, 32, 82, 1] --\n",
"│ └─Conv2d: 2-19 [1, 32, 88, 1] 10,272\n",
"│ └─LeakyReLU: 2-20 [1, 32, 88, 1] --\n",
"│ └─BatchNorm2d: 2-21 [1, 32, 88, 1] 64\n",
"│ └─Conv2d: 2-22 [1, 32, 85, 1] 4,128\n",
"│ └─LeakyReLU: 2-23 [1, 32, 85, 1] --\n",
"│ └─BatchNorm2d: 2-24 [1, 32, 85, 1] 64\n",
"│ └─Conv2d: 2-25 [1, 32, 82, 1] 4,128\n",
"│ └─LeakyReLU: 2-26 [1, 32, 82, 1] --\n",
"│ └─BatchNorm2d: 2-27 [1, 32, 82, 1] 64\n",
"├─Sequential: 1-4 [1, 64, 82, 1] --\n",
"│ └─Conv2d: 2-28 [1, 64, 82, 1] 2,112\n",
"│ └─LeakyReLU: 2-29 [1, 64, 82, 1] --\n",
"│ └─BatchNorm2d: 2-30 [1, 64, 82, 1] 128\n",
"│ └─Conv2d: 2-31 [1, 64, 82, 1] 12,352\n",
"│ └─LeakyReLU: 2-32 [1, 64, 82, 1] --\n",
"│ └─BatchNorm2d: 2-33 [1, 64, 82, 1] 128\n",
"├─Sequential: 1-5 [1, 64, 82, 1] --\n",
"│ └─Conv2d: 2-34 [1, 64, 82, 1] 2,112\n",
"│ └─LeakyReLU: 2-35 [1, 64, 82, 1] --\n",
"│ └─BatchNorm2d: 2-36 [1, 64, 82, 1] 128\n",
"│ └─Conv2d: 2-37 [1, 64, 82, 1] 20,544\n",
"│ └─LeakyReLU: 2-38 [1, 64, 82, 1] --\n",
"│ └─BatchNorm2d: 2-39 [1, 64, 82, 1] 128\n",
"├─Sequential: 1-6 [1, 64, 82, 1] --\n",
"│ └─MaxPool2d: 2-40 [1, 32, 82, 1] --\n",
"│ └─Conv2d: 2-41 [1, 64, 82, 1] 2,112\n",
"│ └─LeakyReLU: 2-42 [1, 64, 82, 1] --\n",
"│ └─BatchNorm2d: 2-43 [1, 64, 82, 1] 128\n",
"├─LSTM: 1-7 [1, 82, 64] 66,048\n",
"├─Linear: 1-8 [1, 3] 195\n",
"==========================================================================================\n",
"Total params: 143,907\n",
"Trainable params: 143,907\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 35.53\n",
"==========================================================================================\n",
"Input size (MB): 0.02\n",
"Forward/backward pass size (MB): 4.97\n",
"Params size (MB): 0.58\n",
"Estimated Total Size (MB): 5.56\n",
"=========================================================================================="
]
},
"metadata": {},
"execution_count": 13
}
],
"source": [
"# Layer-by-layer output shapes for one (batch=1, channel=1, T=100, features=40) input.\n",
"summary(model, input_size=(1, 1, 100, 40))"
],
"id": "xQI93wsNBpjb"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yY1-9SksBpjb"
},
"outputs": [],
"source": [
"# Targets are cast to int64 class indices in the training loop, as CrossEntropyLoss expects.\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-4)"
],
"id": "yY1-9SksBpjb"
},
{
"cell_type": "markdown",
"metadata": {
"id": "eGEpXgZXBpjb"
},
"source": [
"# Model Training\n"
],
"id": "eGEpXgZXBpjb"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6BjXocaRBpjc"
},
"outputs": [],
"source": [
"# A function to encapsulate the training loop\n",
"def batch_gd(model, criterion, optimizer, train_loader, test_loader, epochs):\n",
"    \"\"\"Train `model` for `epochs` epochs, evaluating on `test_loader` after each one.\n",
"\n",
"    Whenever the evaluation loss improves, the whole model is saved to\n",
"    './best_val_model_pytorch'.\n",
"\n",
"    Returns:\n",
"        (train_losses, test_losses): arrays of length `epochs` holding the mean\n",
"        per-batch loss of each epoch.\n",
"    \"\"\"\n",
"    # Run on whatever device the model's parameters live on, so the function\n",
"    # does not depend on a notebook-global `device`.\n",
"    device = next(model.parameters()).device\n",
"\n",
"    train_losses = np.zeros(epochs)\n",
"    test_losses = np.zeros(epochs)\n",
"    best_test_loss = np.inf\n",
"    best_test_epoch = 0\n",
"\n",
"    for it in tqdm(range(epochs)):\n",
"\n",
"        model.train()\n",
"        t0 = datetime.now()\n",
"        train_loss = []\n",
"        for inputs, targets in train_loader:\n",
"            # move data to the model's device; targets must be int64 class indices\n",
"            inputs = inputs.to(device, dtype=torch.float)\n",
"            targets = targets.to(device, dtype=torch.int64)\n",
"            optimizer.zero_grad()\n",
"            outputs = model(inputs)\n",
"            loss = criterion(outputs, targets)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"            train_loss.append(loss.item())\n",
"        # Mean of per-batch losses; a little misleading because the weights\n",
"        # change during the epoch.\n",
"        train_loss = np.mean(train_loss)\n",
"\n",
"        model.eval()\n",
"        test_loss = []\n",
"        # Evaluation needs no autograd graph: saves memory and time.\n",
"        with torch.no_grad():\n",
"            for inputs, targets in test_loader:\n",
"                inputs = inputs.to(device, dtype=torch.float)\n",
"                targets = targets.to(device, dtype=torch.int64)\n",
"                outputs = model(inputs)\n",
"                loss = criterion(outputs, targets)\n",
"                test_loss.append(loss.item())\n",
"        test_loss = np.mean(test_loss)\n",
"\n",
"        # Save losses\n",
"        train_losses[it] = train_loss\n",
"        test_losses[it] = test_loss\n",
"\n",
"        if test_loss < best_test_loss:\n",
"            torch.save(model, './best_val_model_pytorch')\n",
"            best_test_loss = test_loss\n",
"            best_test_epoch = it + 1  # 1-based, like the Epoch counter printed below\n",
"            print('model saved')\n",
"\n",
"        dt = datetime.now() - t0\n",
"        # Implicit string concatenation: a backslash continuation inside the\n",
"        # literal would embed the next line's indentation in the message.\n",
"        print(f'Epoch {it+1}/{epochs}, Train Loss: {train_loss:.4f}, '\n",
"              f'Validation Loss: {test_loss:.4f}, Duration: {dt}, Best Val Epoch: {best_test_epoch}')\n",
"\n",
"    return train_losses, test_losses"
],
"id": "6BjXocaRBpjc"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true,
"id": "Lnon9v-ZBpjc",
"outputId": "248435b7-3e0a-478c-b553-3d7ef4572607",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 1/1 [01:13<00:00, 73.68s/it]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"model saved\n",
"Epoch 1/1, Train Loss: 0.9171, Validation Loss: 1.0344, Duration: 0:01:13.677312, Best Val Epoch: 0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
],
"source": [
"N_EPOCHS = 1  # use 50 to replicate paper\n",
"train_losses, val_losses = batch_gd(model, criterion, optimizer,\n",
"                                    train_loader, val_loader, epochs=N_EPOCHS)"
],
"id": "Lnon9v-ZBpjc"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I5ZlWAm3Bpjd",
"outputId": "91a1fc78-fe59-414c-ef7c-c821f451383c"
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f968ce4b990>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2oAAAFpCAYAAADtINuMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABQn0lEQVR4nO3dd5zdVZ3/8fe5ZeZO7zPpjbTJTBqZFAiQUKQKCCqgsMLqiiIuq+iurKti+eHiLiriooiuiA0WQRSkqRCaQkgC6YX0zKRM7/WW7++Pc6eFJDOTuW0mr+fjMY/v937v997vyeQm977vOedzjOM4AgAAAAAkDle8GwAAAAAA6I+gBgAAAAAJhqAGAAAAAAmGoAYAAAAACYagBgAAAAAJhqAGAAAAAAlmwKBmjPm5MabKGLP5OPfPNsa8YYzpNMZ8MfJNBAAAAIBTy2B61H4h6eIT3F8n6TZJ90SiQQAAAABwqhswqDmO86psGDve/VWO46yR5I9kwwAAAADgVMUcNQAAAABIMJ5YXswYc7OkmyUpLS1t0ezZs2N5eQAAAABIGOvWratxHKfgWPfFNKg5jvOgpAclqayszFm7dm0sLw8AAAAACcMYs/949zH0EQAAAAASzIA9asaYRyStlJRvjKmQdKckryQ5jvOAMWaMpLWSMiWFjDGfkzTHcZymaDUaAAAAAEazAYOa4zgfGeD+I5ImRKxFAAAAAHCKi+kcNQAAAACR4ff7VVFRoY6Ojng3BQPw+XyaMGGCvF7voB9DUAMAAABGoIqKCmVkZGjKlCkyxsS7OTgOx3FUW1uriooKTZ06ddCPo5gIAAAAMAJ1dHQoLy+PkJbgjDHKy8sbcs8nQQ0AAAAYoQhpI8PJ/D0R1AAAAAAMWUNDg370ox+d1GMvvfRSNTQ0DPr8r3/967rnnntO6lojFUENAAAAwJCdKKgFAoETPvbZZ59VdnZ2FFo1ehDUAAAAAAzZHXfcod27d2vBggX613/9V7388ss6++yzdcUVV2jOnDmSpA984ANatGiRSkpK9OCDD/Y8dsqUKaqpqdG+fftUXFysT37ykyopKdGFF16o9vb2E153/fr1WrZsmebNm6errrpK9fX1kqT77rtPc+bM0bx583TddddJkl555RUtWLBACxYs0MKFC9Xc3Byl30bkUfURAAAAGOG+8fQWbT3UFNHnnDMuU3deXnLc+++++25t3rxZ69evlyS9/PLLevvtt7V58+ae6oY///nPlZubq/b2di1evFgf/OAHlZeX1+95du7cqUceeUQ//elPdc011+iJJ57QDTfccNzrfuxjH9MPf/hDrVixQl/72tf0jW98Q/fee6/uvvtu7d27V8nJyT3DKu+55x7df//9Wr58uVpaWuTz+Yb3S4khetT6qt8v7Xg+3q0AAAAARqQlS5b0K0F/3333af78+Vq2bJnKy8u1c+fO9zxm6tSpWrBggSRp0aJF2rdv33Gfv7GxUQ0NDVqxYoUk6cYbb9Srr74qSZo3b56uv/56/frXv5bHY/ujli9frttvv1333XefGhoaeo6PBCOnpbHwzq+k174nffmQ5B05aRsAAACnthP1fMVSWlpaz/7LL7+sv/71r3rjjTeUmpqqlStXHrNEfXJycs++2+0ecOjj8TzzzDN69dVX9fTTT+uuu+7Spk2bdMcdd+iyyy7Ts88+q+XLl+uFF17Q7NmzT+r5Y40etb6KSiQnKFVvj3dLAAAAgISWkZFxwjlfjY2NysnJUWpqqrZv364333xz2NfMyspSTk6OXnvtNUnSr371K61YsUKhUEjl5eU699xz9Z3vfEeNjY1qaWnR7t27NXfuXH3pS1/S4sWLtX37yPmcT49aX0Wldlu5RRq3IK5NAQAAABJZXl6eli9frtLSUl1yySW67LLL+t1/8cUX64EHHlBxcbFmzZqlZcuWReS6Dz/8sD796U+rra1N06ZN00MPPaRgMKgbbrhBjY
2NchxHt912m7Kzs/XVr35Vq1atksvlUklJiS655JKItCEWjOM4cblwWVmZs3bt2rhc+7hCQenb46Wyj0sXfzverQEAAACOa9u2bSouLo53MzBIx/r7Msascxyn7FjnM/SxL5dbKiyWKjfHuyUAAAAATmEEtaMVldigFqeeRgAAAAAgqB2tqFRqq5VaKuPdEgAAAACnKILa0YrCpU0Z/ggAAAAgTghqR+sJalvi2w4AAAAApyyC2tFSc6XM8QQ1AAAAAHFDUDuWohLpCEMfAQAAgEhKT0+XJB06dEgf+tCHjnnOypUrNdAyXvfee6/a2tp6bl966aVqaGgYdvu+/vWv65577hn280QCQe1Yikqkmh1SoCveLQEAAABGnXHjxunxxx8/6ccfHdSeffZZZWdnR6BliYOgdixFpVIoINW8G++WAAAAAAnpjjvu0P33399zu7s3qqWlReeff75OP/10zZ07V3/84x/f89h9+/aptLRUktTe3q7rrrtOxcXFuuqqq9Te3t5z3i233KKysjKVlJTozjvvlCTdd999OnTokM4991yde+65kqQpU6aopqZGkvS9731PpaWlKi0t1b333ttzveLiYn3yk59USUmJLrzwwn7XOZb169dr2bJlmjdvnq666irV19f3XH/OnDmaN2+errvuOknSK6+8ogULFmjBggVauHChmpubT+ZX2o9n2M8wGhXZF40qt0hjSuPbFgAAAGAgz90hHdkU2eccM1e65O7j3n3ttdfqc5/7nG699VZJ0mOPPaYXXnhBPp9PTz75pDIzM1VTU6Nly5bpiiuukDHmmM/z4x//WKmpqdq2bZs2btyo008/vee+u+66S7m5uQoGgzr//PO1ceNG3Xbbbfre976nVatWKT8/v99zrVu3Tg899JBWr14tx3G0dOlSrVixQjk5Odq5c6ceeeQR/fSnP9U111yjJ554QjfccMNx/3wf+9jH9MMf/lArVqzQ1772NX3jG9/Qvffeq7vvvlt79+5VcnJyz3DLe+65R/fff7+WL1+ulpYW+Xy+wf6Wj4setWPJmy65kyjRDwAAABzHwoULVVVVpUOHDmnDhg3KycnRxIkT5TiOvvzlL2vevHm64IILdPDgQVVWHn+N4ldffbUnMM2bN0/z5s3rue+xxx7T6aefroULF2rLli3aunXrCdv0+uuv66qrrlJaWprS09N19dVX67XXXpMkTZ06VQsWLJAkLVq0SPv27Tvu8zQ2NqqhoUErVqyQJN1444169dVXe9p4/fXX69e//rU8HtvvtXz5ct1+++2677771NDQ0HN8OOhROxa3RyqYTVADAADAyHCCnq9o+vCHP6zHH39cR44c0bXXXitJ+s1vfqPq6mqtW7dOXq9XU6ZMUUdHx5Cfe+/evbrnnnu0Zs0a5eTk6Kabbjqp5+mWnJzcs+92uwcc+ng8zzzzjF599VU9/fTTuuuuu7Rp0ybdcccduuyyy/Tss89q+fLleuGFFzR79uyTbqtEj9rxFZVSoh8AAAA4gWuvvVaPPvqoHn/8cX34wx+WZHujCgsL5fV6tWrVKu3fv/+Ez3HOOefot7/9rSRp8+bN2rhxoySpqalJaWlpysrKUmVlpZ577rmex2RkZBxzHtjZZ5+tP/zhD2pra1Nra6uefPJJnX322UP+c2VlZSknJ6enN+5Xv/qVVqxYoVAopPLycp177rn6zne+o8bGRrW0tGj37t2aO3euvvSlL2nx4sXavn37kK95NHrUjmdMqbTht1JLtZReEO/WAAAAAAmnpKREzc3NGj9+vMaOHStJuv7663X55Zdr7ty5KisrG7Bn6ZZbbtE//uM/qri4WMXFxVq0aJEkaf78+Vq4cKFmz56tiRMnavny5T2Pufnmm3XxxRdr3LhxWrVqVc/x008/XTfddJOWLFkiSfqnf/onLVy48ITDHI/n4Ycf1qc//Wm1tbVp2rRpeuihhxQMBnXDDTeosbFRjuPotttuU3Z2tr
761a9q1apVcrlcKikp0SWXXDLk6x3NOI4z7Cc5GWVlZc5A6yPE1Z6XpV9eKX3sj9K0lfFuDQAAANDPtm3bVFxcHO9mYJCO9fdljFnnOE7Zsc5n6OPx9K38CAAAAAAxRFA7nrR8Kb1IOkJBEQAAAACxRVA7kaISKj8CAAAAiDmC2okUlUrV26VgIN4tAQAAAN4jXvUmMDQn8/dEUDuRolIp2CXV7op3SwAAAIB+fD6famtrCWsJznEc1dbWyufzDelxlOc/kaISu63cLBUOb8E6AAAAIJImTJigiooKVVdXx7spGIDP59OECROG9BiC2onkz5RcHhvU5n4o3q0BAAAAeni9Xk2dOjXezUCUMPTxRDxJUv4sSvQDAAAAiCmC2kDGlBLUAAAAAMQUQW0gRSVS00GprS7eLQEAAABwiiCoDaSnoAi9agAAAABig6A2kKJSuyWoAQAAAIgRgtpA0ouk1Hxb+REAAAAAYoCgNhBj7PBHetQAAAAAxAhBbTCKSqWqbVIoGO+WAAAAADgFENQGo6hECrRLdXvi3RIAAAAApwCC2mD0VH5knhoAAACA6COoDUbBbMm4macGAAAAICYIaoPh9Un5MwhqAAAAAGKCoDZYRSUMfQQAAAAQEwS1wSoqkRoOSB2N8W4JAAAAgFGOoDZYRaV2W7k1vu0AAAAAMOoR1AarJ6gx/BEAAABAdBHUBitznOTLpqAIAAAAgKgjqA2WMbZXjR41AAAAAFFGUBuKohI7Ry0UindLAAAAAIxiBLWhGFMq+Vulhn3xbgkAAACAUYygNhRFJXbLPDUAAAAAUURQG4qCYkmGoAYAAAAgqghqQ5GUKuWdJh3ZFO+WAAAAABjFCGpDVVRCjxoAAACAqCKoDVXRXKl+r9TZEu+WAAAAABilCGpD1V1QpGpbfNsBAAAAYNQaMKgZY35ujKkyxhxzpWdj3WeM2WWM2WiMOT3yzUwgPZUfWfgaAAAAQHQMpkftF5IuPsH9l0iaEf65WdKPh9+sBJY9SUrKIKgBAAAAiJoBg5rjOK9KqjvBKVdK+qVjvSkp2xgzNlINTDjGUFAEAAAAQFRFYo7aeEnlfW5XhI+9hzHmZmPMWmPM2urq6ghcOk7GlNqg5jjxbgkAAACAUSimxUQcx3nQcZwyx3HKCgoKYnnpyCoqkTqbpMbygc8FAAAAgCGKRFA7KGlin9sTwsdGr6JSuz3CPDUAAAAAkReJoPaUpI+Fqz8uk9ToOM7hCDxv4iostlvmqQEAAACIAs9AJxhjHpG0UlK+MaZC0p2SvJLkOM4Dkp6VdKmkXZLaJP1jtBqbMJIzpJwpVH4EAAAAEBUDBjXHcT4ywP2OpFsj1qKRoqiUHjUAAAAAURHTYiKjSlGpVLdb6mqLd0sAAAAAjDIEtZNVVCI5Ial6W7xbAgAAAGCUIaidrKISu2X4IwAAAIAII6idrJypkjeNoAYAAAAg4ghqJ8vlkormENQAAAAARBxBbTiKSmyJfseJd0sAAAAAjCIEteEoKpXa66WmQ/FuCQAAAIBRhKA2HBQUAQAAABAFBLXh6Alqm+PbDgAAAACjCkFtOHxZUtYketQAAAAARBRBbbiKSghqAAAAACKKoDZcRSVSzbuSvyPeLQEAAAAwShDUhquoRHKCUs2OeLcEAAAAwChBUBuuMXPtluGPAAAAACKEoDZcudMkj4+gBgAAACBiCGrD5XJLhcXSkU3xbgkAAACAUYKgFglFJXYtNceJd0sAAAAAjAIEtUgoKpXaaqWWqni3BAAAAMAoQFCLhKJSu63cHN92AAAAABgVCGqRUFRitxQUAQAAABABBLVISM2VMsbRowYAAAAgIghqkVJUQo8aAAAAgIggqEXKmFKpeocU6Ip3SwAAAACMcAS1SCkqlUJ+qXZnvFsCAAAAYIQjqEUKBUUAAAAARAhBLVLypkveVGnfa/FuCQAAAIARjqAWKW6vNPdD0sbfSe318W4NAAAAgBGMoBZJSz4lBd
qlt38V75YAAAAAGMEIapE0plSafJa05qdSKBjv1gAAAAAYoQhqkbb0ZqnhgPTu8/FuCQAAAIARiqAWabMukzInSKsfiHdLAAAAAIxQBLVIc3ukJf8k7X1VqtoW79YAAAAAGIEIatFw+o2Sxyet/km8WwIAAABgBCKoRUNqrjT3w9LG/6NUPwAAAIAhI6hFy9JPSf426Z1fx7slAAAAAEYYglq0jJkrTV4uvUWpfgAAAABDQ1CLpiU3Sw37pXdfiHdLAAAAAIwgBLVomv1+KXO89BZFRQAAAAAMHkEtmtweafEnpD0vS1Xb490aAAAAACMEQS3aTr9JcidLbz0Y75YAAAAAGCEIatGWlmdL9W94RGpviHdrAAAAAIwABLVYWHozpfoBAAAADBpBLRbGzpcmnSGtoVQ/AAAAgIER1GJl6aek+n3Szj/HuyUAAAAAEhxBLVa6S/WvplQ/AAAAgBMjqMWK2yuVfVzas0qq3hHv1gAAAABIYAS1WFp0E6X6AQAAAAyIoBZLafnS3A9J6x+ROhrj3ZrEsO9v0v9eJDWUx7slAAAAQMIgqMXakpslf6v0zm/i3ZL462yWnvy0VP6m9OI34t0aAAAAIGEQ1GJt3AJp4jI7/DEUindr4usvd0qN5dKsS6VNv5Mq1sa7RQAAAEBCIKj1caC2Tc9sPBz9Cy39lFS/V9r1l+hfK1HteUVa+7/SGbdKVz8opRVKL3xZcpx4twwAAACIO4JaH79bV67bHn1HLZ2B6F6o+HIpY5y0+oHoXidRdbZIT31WypsunfcVKTnDbstXS1v/EO/WAQAAAHFHUOtj6dQ8BUOO1u6ri+6F3F5p8cel3S9J1e9G91qJ6K932uIhV94veVPssYU3SIUldjikvyO+7QMAAADijKDWx+mTs+VxGa3eG+WgJkmn3yS5k069Uv17X5XW/Exa9hlp0rLe4y63dNFdUsN+6S0WBQcAAMCpjaDWR2qSR3MnZOmtWAS19AKp9EPS+t+eOqX6O1ukP35Wyj3NDnU82mnnSjMukl69R2qtiX37AAAAgARBUDvK0ql52ljRoPauYAwuFi7Vv/630b9WInjxG1LDATvkMSn12Odc+C2pq1V6+T9j2zYAAAAggRDUjrJ0Wq78QUdvH6iP/sXGLZQmLj01SvXvfc3+OZfdIk0+4/jnFcySyj4urX1Iqt4Ru/YBAAAACYSgdpSyyTlyGWn1ntrYXHDpp6S6PdKuv8bmevHQ1Sr98VYpZ6p03lcHPn/lHVJSuvTnQZwLAAAAjEIEtaNk+LwqGZelN2MxT02Siq+QMsaO7lL9fw0PefzAj44/5LGvtHzpnC9KO1+wlTEBAACAUwxB7RiWTs3V+vIGdfhjME/N7ZXKPiHtflGq2Rn968XavtdtFceln5Imnzn4xy39lJQ9WXrhK1IoBn8PAAAAQALxDOYkY8zFkn4gyS3pZ47j3H3U/ZMl/VxSgaQ6STc4jlMR4bbGzNJpefrZ63u1obxBS6flRf+Ci26SXv0vafVPpIu+LQU6+vx0HrUN7/vb+98X7JJmXSoVzIx+ewer75DH8782tMd6kqX3fVP63Y3SO7+WFt0YnTYCAAAACWjAoGaMcUu6X9L7JFVIWmOMecpxnK19TrtH0i8dx3nYGHOepP+U9A/RaHAsLJmSK2Ok1XvrYhPU0guk0g9Ka35qf07W69+XbnxKGjs/cm0bjhe/KdXvk256VkpKG/rj51wpTVwmvfT/pNKrpeSMiDcRAAAASESD6VFbImmX4zh7JMkY86ikKyX1DWpzJN0e3l8l6Q8RbGPMZaV6NasoQ6v31kqaEZuLnvdVKWeK5PJIHp/tUfL4JG9K736/7VHHW2ukX31A+uWV0seeksbOi027j2ff3+y8uyWfkqYsP7nnMMb2MP7sPOn1e6XzKS4CAACAU8Nggtp4SeV9bldIWnrUORskXS07PPIqSRnGmDzHcWJUOjHylk3L06NrDqgrEFKSJwZT+bLG22
qHJyslW7rxaekX75d+eYXdHzM3Ys0bkq628JDHKdIFdw7vuSYskuZeI73xP3aIaPbESLQQAAAASGiRSiBflLTCGPOOpBWSDkp6TwUIY8zNxpi1xpi11dXVEbp0dCydmqsOf0ibDjbGuymDlztVuulpyZsqPXyFdGRzfNrx4jel+r3SFf9zckMej9Y9v+3Fbw7/uQAAAIARYDBB7aCkvt0YE8LHejiOc8hxnKsdx1ko6T/CxxqOfiLHcR50HKfMcZyygoKCk291DCyZmitJ4eGPI0juNNub5vHZnrXKLbG9/v6/h4c83ixNPTsyz5k9UTrjVmnTY1LFusg8JwAAAJDAjOM4Jz7BGI+kdyWdLxvQ1kj6qOM4W/qcky+pznGckDHmLklBx3FOWOavrKzMWbt27XDbH1UXfO8Vjc9O0cMfXxLvpgxd7W7pF5fZapA3/kkqmhP9a3a1SQ8st+X0b/m7lJweuefubJbuWyjlniZ9/Hk7fw0AAGCo2uqkV/5LqttjR/4kpUlJ6eFtap/9tOPve9MkF6tcYfiMMescxyk71n0DzlFzHCdgjPmspBdky/P/3HGcLcaYb0pa6zjOU5JWSvpPY4wj6VVJt0as9XG0dGqu/rj+kALBkDzuEfaPMe806aZnpIculR6+XLrpT1JhcXSv+dL/s//p3fh0ZEOaZCs+nvcV6el/kbb+USr5QGSfHwAAjG6hkLT+19Jf7pQ6GqWiErvcUVdr+KdZckKDey53kjR5uTTzIvuTOy26bccpacAetWgZCT1qT204pNseeUdPfXa55k3IjndzTk7NLtuz5gRtz1rh7OhcZ/8b0kOXSIs/IV323ehcIxSUHjhb6mqRPrvGVrsEAAAYyOGN0jNfkCrekiadYT+rFJX0P8dx7Nq0Xa32s0ZPgDvGfvNhaddfpZp37WPzZ4ZD28XSxKWS2xv7PyNGpBP1qBHUTqCyqUNLv/2i/uPSYn3ynBH8TUnNznBYc2zPWsGsyD5/V5v0wFlSyC/d8kbke9P62v2S9KurpPd9S1p+W/SuA6BXKMQQHyBaggH7od+bKqXmMrQ/0joapVXflt56UErJlS78f9L86yL3e67bI737Z+nd56V9r9vPQr4safoFNrRNv8D+vQLHQVAbhpX/vUrTC9P1sxsXx7spw1P9rg1rkh0SWTAzMs/beNBWY9z4qF2/bdqKyDzvifzmw9KBN6Xb3pHS8qN/PeBU1VAu/fk/pF0vSZfdYz/cABiaQJfUdFBqOGB/Gst79xvK7X1OuFC2O0lKHyNldP+MlTLH2m337YwxUnImgW4gjiNtfkJ64ctSS5Ud8XPeV6SUnOhds7NZ2r1KevcFaeefpdYqybikCUt6e9sKi4f3dxfokoKddr5cIr8G/O1S9Q6paptUtVWq3m73g35p3EJp/CJp/EJp3OmnfJAlqA3Dlx7fqOc2H9b6r10olyuB/0EMRvUOu86aMTas5Z/kYt6dLdL2P0kbHpH2vCLJkc68TbrwWxFt7nFVbZd+fKZU9nH74RFAZPk7pL//UHotPIw5b7pUucmuZXjxdySvL67NAxJOZ7NUsfbYYazpkKS+n7WMlDlOyp5kf7ImSlkTpECH7VlrPtJne0TqbHrv9byp/YNb3nRp4Q32+WA/7zzzBWnfazYIXPZdafzpsW1DKCQdfseGtneflw5vsMezJtnQljdd8reG58i12f2uNsnfZodW+tuOOt5u90MB+zyeFPs66vczvnc/Y5yUVhD90RCBLql2V/8wVrVVqturnte9O8kODS0stsH14NtS7c7e58iZGg5up9vtmHm2qMspgqA2DL9/u0K3P7ZBz952tuaMy4x3c4avarv08Psl4w6HtemDe1woZP/D2/CItPUp+59F9mRp/kekedfY4iWx9MwXpLUPSZ95I/JDOXFsjmM/cBzZaL+RnLQs3i0aGsexQ2fHzqcn9ngcR9rxnPTCv0v1+6Q5V9phQhnjpJe+Jf3tXvsGes0v7b
qNwKksFJT2viJteFTa9rT9YC3Z99fM8b1BLHti/1CWOV7yJA3+Op0tUkvlMUJcn239Pntu8eXSss/YOVKJ3NsSDEgH/m7/v3F5pAmLpQllNmAMR1erreb4xv32g/75d9ovmFzuiDR7WJoO2V62d/8s7VnV+3qRCVeRTLVt9qZJ3pTe/aTU8H3hc7ypdv5ba7XtjW06bJ+7+VBviOvm8oZ7ZY8KdL4se11jTrDVsY+HArayeNVWG8pqd/Ve17jt58HCYqlwjlQw225zp0nuo+oXdjRKh9ZLB9dJh9624a3pYO/zFBbb4DYuHN4Ki0ftvD+C2jAcbGjX8rtf0p2Xz9E/Lh8lH0yqttmeNbfXhrUThazqHfZNaONjUlOFHW5R8gEb0CadEb83gtYaW65/0hnS9Y/Fpw2jWSgk1e223wAe2WgnYR/ZKLX1WVdw0U3SRd+OzKLm0Va9Q/rT56X9f7Mh86Jv29dwIn+QibWaXdLzd0i7/iLlz5Iu/S9p2sr+5+x4TnryU/ZL0qt+LM2+LB4tBeKr+l1pw2/D74sH7YfekqvtFxt5p9kvNo7+UBptjRXSWz+V1v1C6miwQ8uWfUaa84GhhcJoCnRJe1+Vtv1R2v6MfT9xJ9sqiyG/PSdjnA1sE8qk8WXSuAWDe49xHPucz99hezMXXC9d8A0pPUHX7A102gCelGrXvY3Ee1Eo1BvemsPhrelgeNvnJ9A+/GvJSDmTbQgrLJYKiu02f8bwCr01V4ZD2zob3A6us69nyf6exsy1vc+peVJqvv3SNTXX7qfm2dspuYnzmh8kgtownfWdl1Q6LksP/MOieDclciq32rL9xwprrbV2XPeGR+w/GOOWpp9v56fMutR+05MI/vYD6S9fk274vW0fTk6g04b3voHsyGbbayrZb+QKi6Wx86Qx8+1/lO8+b3//udOkq38qTUjQfxv+Djt87/Xv2zf7c74obfuTVP6mNPUc6f33xr43ONF0tkiv/rf9BtqbIq28wy5Yf7xvLuv3SY/dKB1eb4c8n/+1Ufst5yknFLSFEWretWtwDttA39gfY+ty2dEa2ZNjH3ZOpK2u933x4Lo+74sfCb8vJshw4K5W++Xqmz+2Q8vSx0iL/0kq+8f4jCTwt0u7XpS2PSXteF7qbLRzq2ZeJBVfIc14n/1dHtkkHVwrVawJDyHdbx9v3HYd2PFlvb1ueTP6D+er2ys992+2t6qwxA5znHxG7P+sI4HjSO31dqiuHHu7Z6ujbh9na1w2pMXiS1rHker3hkPb2/bL45ZKG/Lb69V/SHEfyVk2wKXl94a61Fzbm7jslui3e4gIasN0+2Pr9fKOaq37ygUyo+kb+Mot4bCWLP3Dk/bNecOj0s4XbDf2mLn2Taj0Q1JGUbxb+16BTun+JfYborJPSGd/IXbfnoVCtuehq0WacVF0K11GWludtOX34f/0Ntox5d3fZial27/3MfPCwWyeHbpwrG+n9r0uPflp+/tfeYd01u2J9cFqz8vSn263PYPzrpUuvMu+PkIh6e1f2HV0gl3Sin+zgeNUCxvdE+3//FU7ZGbB9XaY0GD+rfs77AT9tf8rTTpT+tDPbcGDROE49s28apv9MDF2wcj5htVx7Lfg9ftt72/3MKVIv/e0VEmVm+2XdlVb7X71DjtPKhG4k+wcnvyZdnh7wSzb05s3PXahKOi35dfX/9Z+ORXsskFgwUekudck5vtit1DIDvV+80fS7hft+/y8a+yH1KNL0kdaZ7MNTVufknb+xX7p58u2PfDFl0vTzh3477ClOhzc1trtwbd75+olZ9khcRPKbG/c3//H/v997pdP/CUTRpdgwPa2tdZIbTU2vLXW2M84/W7X9u6n5Uu3b413y9+DoDZMj60p1789sVF/+fw5mlGUEe/mRNaRzTastdfZ2+lF9j/zeddJY0rj27bBaDokvfyf0ju/tuO2z/isdMatki9K8wmDAfvh9vXv2YAj2Qm9My+0Q19mXJi4E2
APrpPe+pltf7DTTjLuG8jGzrcTeocy8bi9QXr2X6VNj9mqVlf/JP6LfrbWSC/8h61EmjtNuux70mnnvve8psPS81+yC6gXzpEuv0+aOMKruw7Wkc32G+j9f7Mh5tL/liYuGfrzbPyd9PRtNgx98H9jU/X1aG114Qns4fkS3RPZ2+t7z/H47ByHScvscOkJi6WU7Ni3ta/uQFa1XareFv4zbLdhqau5/7netKOKBYw9qnDAePut8bHCXFebff7KrfbLuaotdr+tpvec9CL7b6CopHdeyXD/HxvMN/Nywl+I97kd9Nte25oddohh9Xbbu9K9CLEJ97gVzOoT4mbbIVe+rOG1udvhjbbnbNPv7FCy1Hxp7odtQBszb+QNma7aLq1+wH4RG2iXpq6wgW3GRZErNNFeb4dGb33KBsTu95jZ75fmXCFNOXt4ASoUsl8m9/S6rbOvZSdk33svumv489swujmO/SIqUUaF9UFQG6b9ta1a8d8v61sfKNU/LJsc7+ZEXuUWae3PpZmX2DkpidQrMljV79piB9uesh9Yzv6iLcUbqUWxA532W9W/3Ws/RBTOsT14GWOlLU9KW/9g39C9adKsS6TSD9phMfFelNvfLm3+vbTmZ3YYqzfNDmFd/An7Z4jUB45Nj9veKycoXfId2zsT6w8zoZC0/td2OGxni3TW5+zf0UD/KW9/Vnr2izb0L/mkdN5XIx/0gwE7ZCMpzfYIxOvfWHu9XU9ozc/sN9znf006/WPDm2hftV167GN2mNW5X5bO+kJ0qox1tfapKLatN5g1H+49JzkzPIm9z0T2jka7nMeBN+zfgROUZKSi0nBwC4e3rPGRb7M0+ECWVmhDR2GxbXfuVPtFSM/ckj7zTpoPv7dogDspXDQgHN6Cnfb/9r6V17yp9rmLSnpDWVFJ4hfX8XfYggXd4a1mh/391e7qP0QzY6ydv+JOtv/GXF4bDlweu3Un9e4f675QwM5zqtxs7591sTT/o3Z43mjopWmrk95+WFr9oO1Fz50mLf20tOCjUnKfL6Edxxa66GwO/zTZ/1N7bncfC+/X7bHFxkIBKXOC7TWbc4UtaBLNIh5drfaLuZxR+LkMpxSC2jA5jqMz/vMllU3J0f98NMblXTE0B9dJf/2GrcSVNdF+cJx37cm/WXS12snZf/+h/XA0fpENgTMv7v9hNBiQ9r9uQ9G2p+wH4uQsO9Sj9IO2pyGWb/T1+6Q1/yu98yvblvxZdp7C/Oui19vYUC794Rb7hl18ue2hitXaKFXbbbGQA3+3Q/Euv3do1UA7m6WX/p+0+if2w95l9wy/UEZns/1mecdztjxzd6+1O1kqnC0VzbW91kUlNjRE83cVCtrXwovftK+Hsk/YfxuRumZni/T0v0ibH5emv0+6+sHhPXdni/23XPGWHfJUtbW3qp1ke8gKZvVOZO/eZo4/8RcEXa12KFV3cKtYY4cvS7Zkdt/gVjD72IEz6Lfhr6PRDrtpb+jd7z7efazhwMCBrGC23R/K7ysU7FPx7VC46tvB/qHO5e7fS1ZUMvQe80QXDNjetuodvSGu6aANDEG/HdIdDIS3XX32+9wX7Ood+i3ZCnMLPmr/3x6tazsF/fZ96s0f238DyZk24PYNYN09mCfi8tr3k7RCO6qk+Eo7JHGk9TgCcUZQi4DbHnlHb+yp1VtfPn90zVMbrXavkv76dVvwoGC27TmYdeng30DaG2wFrTd/ZD9gTznb9s5MWznwcwT9dn25zU/Yb2c7G20VouLLpdKr7XNF41vGUMjOp1jzUzsvwLhs2FjySXvNWLxuQyHpjR9KL37L9mx+4EfRLfTibw8XC7nXzhN837dsb97JfhitWGeH8lVutn9fl/zX0IbTNB2ywWzHs7a6WbDL9lzNvMgOiw0F7HMf2Wy3rdW9j80Y1z+4FZUOrvct0GnnG7VU2cVVe7bVdp5Wa3Xvuk6TzrTVHMfMPalfzwk5jp2z9vy/26F0H/6FnUMymM
fV75PK37LBrHy17Qnq/qCYN8P+XvqGspwpkfk3FAz/fRx4I/zzpv2dSXYY3dj5/YNZe0NvkZ3jcXntsEpflg39wwlkiA3HseHXCcZ/FESsla+R1j1kX9/JmbZnrd/P8Y6ln3q/KyBKCGoR8JvV+/UfT27Wqi+u1NT8EVCOHPbNd+sf7ZDI2l12XsoFX5emnHX8x7RU23C25mf2m8UZF9mANmnpybUh0GkrXm35vR1i52+14/bnfMCWcs6dZgPNcCbHt9XZOXpr/9d+4E0rtKXzF90UveFcAzm8Ufr9J+0wr6Wftr/3SI8L371KeuZ2O+xm/kfsel+RGMIV9Etv/I/08t12ONQFd0qLPn7s8Oc4tlpZdzg7vN4ez5lqQ/KsS6SJy44ftlqq7OMrt4QLO2yxvQPd3/D37X3LnmgnRLccFcY6G4/93L4s+1pIL7SvueLLbS9BtAP7wbdtVcjmw3YZhCWf7H9Nf7tdO6d8tf02v3x1b2BNSre91hOX2jlzE8psQY1Y6a4w1t3jVrnFDhf0ZdnA3R3AfN3brPce86bQowAAGDSCWgTsqmrRBd97RXdfPVfXLZkU7+ZgKIIBaf1v7Afv5kPS9Atsdbux83rPaaywwxvXPWwnm8650ga0vucMV1ebrYS15fd2wcu+a5l406S0vHAZ2bz+5WS7b/eUmc2zH14Pb7DDGzc/bts86UxpyT9Jsy9PjAp3/nbbq7n6AdubcPVPI/P7bKm2FQc3PSblnia9//vRKWJRt8cOp9zzsg0Ol//A9ogEuuww1+3P2oDWVCHJ2C8CZl9qe27zZ578h/VAl500X7m5T+/bFhvMfNnh4FVoK1h2b9OL+h9LK4hvufC2OjsM9t3npZKrbEgsX2N7zA5v7A2iudNsEZqJ4Z/COYmxMC0AADFCUIsAx3G0+K4XdfaMfH3/2gXxbg5Ohr/dDmd87bt2Pknph2yv08b/s9Ww5Nj5bGd93lYQi6bOFlvevuVIn3Ky4RKy3WVl2+p658+8h7Ht9aZJ86+188+iXXL5ZO36q/SHz9g/z/lflc745+MPTQz67Ryqnt9Fn99L9/Edz9m5RmffbpcEiGYgcRz7+nj+3+28jWkrpAOr7ZwjT4p02nm212zmRTZARVMwMLIK/YRCtvjOS9+ywxi7Ky9OWGyD74TFibsYLQAAMUJQi5Bbf/O23jlQr7/dcR7z1Eay9gbp7/fZidT+Nju87PSPSctvk7ITrLfU32HnyPVdC6Q7vKQX2JLRkSpJHU2ttXbu1/Y/SZOXS+MWHiOM1R1/GJ9kQ2lqnl389H3fkgpmxrb9f/6KLWc/bYXtNZu2MiHL/Cac6ndtsB0zb3RUzgMAIIIIahHy8N/36c6ntui1fztXE3MTdK0sDF7zEdvbM/19ib1w6WjhOHYI6gtftr1DqXlSak7/4Z4puX2Ge+b2Px7PoXwAAABRcKKgNoLG0cTf0mm2WtfqvXUEtdEgY4y08IZ4t+LUYYz9fcdjjTUAAIARZhQtqBJ9MwszlJ3q1eo9tfFuCjByEdIAAAAGRFAbApfLaMmUXK3eWxfvpgAAAAAYxQhqQ7Rkaq4O1LXpcGP7wCcDAAAAwEkgqA3Rsml5kqS36FUDAAAAECUEtSEqHpupDJ9Hb+4hqAEAAACIDoLaELldRoun5Gr1XgqKAAAAAIgOgtpJWDI1V3uqW1Xd3BnvpgAAAAAYhQhqJ2HpVLueGvPUAAAAAEQDQe0klI7PUmqSm+GPAAAAAKKCoHYSvG6XFk3O0WoKigAAAACIAoLaSVo2LU87KptV19oV76YAAAAAGGUIaidpSXie2pp99KoBAAAAiCyC2kmaNyFLyR4Xwx8BAAAARBxB7SQle9w6fVIOBUUAAAAARBxBbRiWTsvV1sNNamz3x7spAAAAAEYRgtowLJmaK8eR1u1n+CMAAACAyCGoDcPpk3KU5GaeGgAAAIDIIqgNg8/r1vyJWXpzL0ENAAAAQOQQ1IZp6dQ8bT7YqJ
bOQLybAgAAAGCUIKgN05KpuQqGHL29vz7eTQEAAAAwShDUhmnR5By5XYYy/QAAAAAihqA2TGnJHs0dn0VBEQAAAAARQ1CLgKXTcrWhokHtXcF4NwUAAADAKEBQi4BlU/PkDzp6p5x5agAAAACGj6AWAYum5MhlxPBHAAAAABFBUIuATJ9Xc8ZlUlAEAAAAQEQQ1CJk6dQ8vXOgQZ0B5qkBAAAAGB6CWoQsnZqrzkBIG8ob490UAAAAACMcQS1ClkzNlTHSs5sOx7spAAAAAEY4glqEZKcm6brFk/SLv+/Ti9sq490cAAAAACMYQS2C7rx8jkrGZerz/7deB2rb4t0cAAAAACMUQS2CfF63fnz9IknSLb9Zpw4/hUUAAAAADB1BLcIm5aXq+9cu0JZDTfr6U1vi3RwAAAAAIxBBLQrOLy7SreeepkfXlOuxteXxbg4AAACAEYagFiW3v2+WzjwtT1/9w2ZtOUTJfgAAAACDR1CLErfL6L6PLFR2qlef+c3bamz3x7tJAAAAAEYIgloU5acn60fXn66D9e364u82yHGceDcJAAAAwAhAUIuyRZNz9eVLi/WXrZX6yat74t0cAAAAACMAQS0G/nH5FF02b6z+6/ntemN3bbybAwAAACDBEdRiwBij73xwnqbkp+mfH3lHlU0d8W4SAAAAgARGUIuR9GSPHrhhkVo7A/rsb9+WPxiKd5MAAAAAJCiCWgzNLMrQ3R+cqzX76vVfz2+Pd3MAAAAAJCiCWoxduWC8PnbGZP30tb16btPheDcHAAAAQAIaVFAzxlxsjNlhjNlljLnjGPdPMsasMsa8Y4zZaIy5NPJNHT3+47JiLZiYrX99fKP2VLfEuzkAAAAAEsyAQc0Y45Z0v6RLJM2R9BFjzJyjTvuKpMccx1ko6TpJP4p0Q0eTZI9b919/urxuo1t+/bbaugLxbhIAAACABDKYHrUlknY5jrPHcZwuSY9KuvKocxxJmeH9LEmHItfE0Wl8dop+cN1CvVvVrK88uZnFsAEAAAD0GExQGy+pvM/tivCxvr4u6QZjTIWkZyX9c0RaN8qdM7NAnzt/pn7/zkH99q0D8W4OAAAAgAQRqWIiH5H0C8dxJki6VNKvjDHveW5jzM3GmLXGmLXV1dURuvTI9s/nTdfKWQX6xlNbtbGiId7NAQAAAJAABhPUDkqa2Of2hPCxvj4h6TFJchznDUk+SflHP5HjOA86jlPmOE5ZQUHBybV4lHG5jL5/zQIVZCTr079ap701rfFuEgAAAIA4G0xQWyNphjFmqjEmSbZYyFNHnXNA0vmSZIwplg1qdJkNUk5akn7yD4vUEQjp6h/9Tev218W7SQAAAADiaMCg5jhOQNJnJb0gaZtsdcctxphvGmOuCJ/2BUmfNMZskPSIpJscqmMMSen4LP3+ljOVnZqkj/x0tZ7ZyBprAAAAwKnKxCtPlZWVOWvXro3LtRNZfWuXPvnLtVq7v17/fsls3XzONBlj4t0sAAAAABFmjFnnOE7Zse6LVDERREhOWpJ+/U9L9f55Y/Wfz23XV/6wWYFgKN7NAgAAABBDnng3AO/l87p133ULNSEnVQ+8sluHGtr1Px89XWnJ/HUBAAAApwJ61BKUy2V0xyWzdddVpXp1Z42u+ckbqmzqiHezAAAAAMQAQS3BXb90sn52Y5n21bTqqvv/pu1HmuLdJAAAAABRRlAbAc6dVajHPn2Ggo6jD//4Db2+sybeTQIAAAAQRQS1EaJkXJae/Mxyjc9J0U0PvaXH1pTHu0kAAAAAooSgNoKMy07R7z59hs44LU//9sRGfffPO8RydQAAAMDoQ1AbYTJ8Xv38psW6tmyifvjSLt3+2AZ1BoLxbhYAAACACKLe+wjkdbt09wfnamJuiu7587s63Niun9xQpqxUb7ybBgAAACAC6FEboYwx+ux5M3TvtQv09v4GXf3jv2lvTWu8mwUAAAAgAghqI9wHFo7XLz+xRDUtXbro+6/q7u
e2q6UzEO9mAQAAABgGgtoosGxanv78+XP0/vlj9cAru7Xyv1/WY2vKFQpRaAQAAAAYiQhqo0RRpk/fu2aB/nDrck3KTdG/PbFRV9z/ut7aWxfvpgEAAAAYIoLaKLNgYraeuOVM/eC6Bapt6dI1P3lDt/72bVXUt8W7aQAAAAAGiaA2ChljdOWC8XrxCyv0L+fP0IvbKnX+d1/Rd/+8Q63MXwMAAAASHkFtFEtN8ujz75upl76wUheXjtEPX9ql8777sp5YV8H8NQAAACCBEdROAeOyU/SD6xbqiVvO1JhMn77wuw266sd/17r99fFuGgAAAIBjIKidQhZNztGTn1mu7354vg43tOuDP/67/uXRd3SooT3eTQMAAADQhyfeDUBsuVxGH1w0QReXjtGPX96tB1/boxe2HNHN55ymf1g2WQUZyfFuIgAAAHDKM44Tn7lKZWVlztq1a+NybfQqr2vT3c9t1zObDsvjMnrfnCJ9ZMkknTU9Xy6XiXfzAAAAgFHLGLPOcZyyY95HUIMk7apq0f+tOaDH11Wovs2vCTkpurZsoq5ZPFFFmb54Nw8AAAAYdQhqGLTOQFB/3lKpR946oL/vrpXbZXTurEJ9dOlErZhZKDe9bAAAAEBEnCioMUcN/SR73Lp8/jhdPn+c9tW06v/Wlut3ayv0122VGpvl0zXhXrbx2SnxbioAAAAwatGjhgH5gyG9uK1Sv32rXK/trJYkrZxZoOuWTNJ5swvldVM8FAAAABgqhj4iYsrr2vS7teX6v7XlqmzqVEFGsj68aII+XDZRU/PT4t08AAAAYMQgqCHiAsGQXt5RrUfeOqBVO6oUcqTS8Zm6fN44XTZvrCbkpMa7iQAAAEBCI6ghqo40duhPGw/p6Y2HtaG8QZJ0+qRsXT5/nC6bO1aFVI0EAAAA3oOghpg5UNumP206pKc3HNa2w00yRlo6NVeXzx+nS0rHKjctKd5NBAAAABICQQ1xsauqRX/aeEhPbTikPdWtcruMlk/P1+XzxurCkjHKSvHGu4kAAABA3BDUEFeO42jb4WY9vfGQ/rTxkMrr2pXkdmnFrAK9f95YXVBcpLRkVooAAADAqYWghoThOI42VDTq6Q02tFU2dSrZ49IZp+Xp3FmFWjmrQJPzqB4JAACA0Y+ghoQUCjlau79ez246rFferdbemlZJ0rT8NK2YVaCVswq1dGqufF53nFsKAAAARB5BDSPCvppWvbyjSi+/W603dteqMxBSitcd7m2zwW1iLmX/AQAAMDoQ1DDidPiDemNPrV7eboPb/to2SdJpBWlaOatQ584q1OKpOUr20NsGAACAkYmghhFvb02rVoVD25t7atUVCCk1ya0zT8vT2TMKtHhKrmaNyZDbZeLdVAAAAGBQCGoYVdq6AnpzT61Wba/Wy+9WqbyuXZKU4fNo8ZRcLZ6SqyVTczR3fLaSPK44txYAAAA4thMFNWqiY8RJTfLovNlFOm92kSSpor5Na/bV6a299Xprb61e2l4lSUr2uLRwUraWTMnVkql5Wjgpm2UAAAAAMCLQo4ZRp7alU2v21eutvXVas69OWw41KuRIbpdR6bhMLZma29PzlpOWFO/mAgAA4BTF0Eec0po7/Hr7QIPW7K3TW/vqtL68QV2BkCRpZlG6zjwtX+fMzNfSqXn0uAEAACBmCGpAHx3+oDYdbNRbe+v05p5avbW3Tp2BkLxuo4WTcnT29HydPbNAc8dnUZwEAAAAUUNQA06gwx/Uuv31em1njV7bWa0th5okSVkpXp15Wp7OmpGvc2YUsIYbAAAAIoqgBgxBbUun/ra7Vq/vrNZrO2t0uLFDkjQ5L1VnTc/X2TPydcZp+cpK8ca5pQAAABjJCGrASXIcR7urW/X6zmq9vqtGb+yuVWtXUC4jzZ+YrTOm5WnhpBwtmJitgozkeDcXAAAAIwhBDYgQfzCkdw402N62XTXaVNGoQMj+G5qYm6IFE3O0cGK2Fk7K1pxxmUr2uO
PcYgAAACQqghoQJR3+oDYfbNQ7Bxr0Tnm91h9o0KHwUMkkt0sl4zO1YGK2Fk6yAW5CToqMoUAJAAAACGpATB1p7ND68nob3g40aOPBBnX47XIA+enJWjgpOxzeslU6PkuZPua6AQAAnIpOFNRYNAqIsDFZPl2cNVYXl46VZIdL7jjSrHfKG/TOAdvr9petlT3nT85LVem4LJWMz1TpuCyVjs9SLgtxAwAAnNLoUQPioKGtS+vLG7TlUJO2HGrUpoONKq9r77l/XJZPJeOzwsEtU6Xjs1SYkcywSQAAgFGEHjUgwWSnJmnlrEKtnFXYc6yxza8thxq15VCTNh9q1OaDjfrrtkp1f5eSn55sQ1s4vM0Zm6UJOSlysSg3AADAqENQAxJEVqpXZ07P15nT83uOtXYGtO1wkzYfbNTmQ3b72s4aBcOVJtOS3Jo5JkOzx2Rq9piM8E+mslKZ9wYAADCSMfQRGGE6/EFtP9KsbYebtCO83X6kWY3t/p5zxmb5NHtMhmaNyVTx2AzNGpOhafnpSvK44thyAAAA9MXQR2AU8XndWjDRVo7s5jiOKps6te2IDW/bw+Ht9V018gftlzFet9FpBek9AW5mUbpmFmVofDbDJwEAABINQQ0YBYwxGpPl05gsn87tM++tKxDSnpqWcM9bs3YcadLqvXX6w/pDPeekeN2aXpiuGYXpmlGUoRmFNsAx/w0AACB+GPoInIIa2/zaWdWsnVUt2lnZYvcrW3SkqaPnHJ/XFQ5wGZoeDm8zCtM1MTdVbgIcAADAsDH0EUA/WalelU3JVdmU3H7HG9v92lXVol1VzXq3skU7q1q0ek+tnnznYM85yR4b4IrH2gImxWMzVTw2k7XfAAAAIoigBqBHVopXiybnaNHknH7HmztsgLM9cM3aUdmiV96t1uPrKnrOKcxItuFtbIaKx9jwNq0gTV43BUwAAACGiqAGYEAZPq8WTsrRwkn9A1xNS6e2H27W9iNN2nq4SdsPN+uN3bXqCoYkSUlu2/vWN7zNHpuh/PTkePwxAAAARgyCGoCTlp+erLNmJOusGb1rv/mDIe2pbu0X3l7fWaPfv32w3+Ns9Un7UzwmUzOK0uXzuuPxxwAAAEg4BDUAEeV1u3oC2JULxvccr23p7Lf+2/Yjzfr1m/vVGbC9by4jTclL63ls9yLek3JTqT4JAABOOQQ1ADGRl56s5dOTtXx6b+9bMORof22rXT7giF0+YNvhJj2/5Yi6C9KmeN2aWZTeL7wVj81UDsVLAADAKDao8vzGmIsl/UCSW9LPHMe5+6j7vy/p3PDNVEmFjuNkn+g5Kc8P4HjaugLaWdmi7Ufswt07wj+1rV0954zN8mlOuOJk8dhMzRmXqcn0vgEAgBFkWOX5jTFuSfdLep+kCklrjDFPOY6ztfscx3E+3+f8f5a0cNitBnDKSk3yaP7EbM2fmN3veHVzp7Ydbur52Xq4SS+/W61gyAk/zt3T4zZnXGbPEgKpSQweAAAAI8tgPr0skbTLcZw9kmSMeVTSlZK2Huf8j0i6MzLNA4BeBRnJKsgo0DkzC3qOdfiD2lnZ0hPcth5u0lMbDuk3qw9IkoyRpualhXveMjQhJzX8PMkqSE9WdqpXxtALBwAAEstggtp4SeV9bldIWnqsE40xkyVNlfTS8JsGAAPzed2aOyFLcydk9RxzHEcV9e3hnrdmbT3cqE0HG/XMpsPvebzXbZSfnqz89N7w1hPkMvofS0umZw4AAMRGpD91XCfpccdxgse60xhzs6SbJWnSpEkRvjQAWMYYTcxN1cTcVF1YMqbneGtnQJVNHapu7lR1S6fddv+0dKqyqUObDzaqtrWrZzhlXxk+j04rSNeMwnTNKErX9MJ0zSjM0PjsFObGAQCAiBpMUDsoaWKf2xPCx47lOkm3Hu+JHMd5UNKDki0mMsg2AkBEpCV7NK0gXdMK0k94XjDkqL6tSzVHhbmDDe3aVdWil9+t1u/WVfSc7/O6+g
S4DE0vtCFucm6qPG5XtP9YAABgFBpMUFsjaYYxZqpsQLtO0kePPskYM1tSjqQ3ItpCAIgxt6t3OOTsMcc+p7HNr13VzdpZ2aKdVfZnzb56/WH9oZ5zktwuTc1P6w1ueakan52iCbmpKspIJsQBAIDjGjCoOY4TMMZ8VtILsuX5f+44zhZjzDclrXUc56nwqddJetQZTL1/ABjhslK9WjQ5V4sm5/Y73tIZ0O6qFu0Kh7ddVc3afKhRz24+rL7/O7pdRmMyfZqQk6LxOSmakJ2iCTmpdj8nRWOzUpTkIcgBAHCqGtQ6atHAOmoATiUd/qAONbSror5dBxvadbC+XRX1bT37R5o61HdanDFSYUayDW/ZKRqXnaLCPgVOuvfTkz1UrQQAYIQa1jpqAIDh83ndJ5wf5w+GdKSxoyfIVdS36WB4f315g57bfFj+4Hu/WPN5XSrM8PWrTtk/0Nn78tKT5GWoJQAAIwZBDQASgNft6qlUeSyO46ix3a/q5k5V9SlwUtXcW8Vyd3WL3txbq4Y2/3seb4yUl2ZDXFFmsooyfSrM9IVv+3qO5aUlMXcOAIAEQFADgBHAGKPs1CRlpyZpRlHGCc/tDARV09LVL8xVNfVuK5s7tPlQk2paOnX06HeXkfLSw2Euw6fCcICbkJOqCTkpmpibqjGZPrlZjgAAgKgiqAHAKJPscWt8dorGZ6ec8LxAMKTa1i5VNnWoMhzkKps6VdXUocqmDh1p6tCGikbVtvYPdB6X0bhsW/RkYp8A170tSE9mXTkAAIaJoAYApyiP2xUe9ug74XldgVBPIZTy+jZV1LepvM7Oo3tpR5Wqmzv7nZ/kcWlCdriaZY7tgcvPSOpZ8qAg3c6ZS0vmLQgAgOPhXRIAcEJJHpem5KdpSn7aMe/v8Af7hLh2VdS19dzefPCw6o8xZ06SUrzungCXl5asgj5hLi+9z35akrJSvPTSAQBOKQQ1AMCw+LzunkW9j6UrEFJta6dqW7pU3dKpmuZO1bZ2qaa5UzUtdr+ivk3ryxtU19rZb5mCbm6XUU6qV7lpScpNS1JeWnLvfrrd5qYmKbfPPkVRAAAjGUENABBVSR6XxmbZRbwHEgw5amjrUk1Ll2pabJCra+1SXWuXalu7VNdi97cfaVJda5ca2v3vKYjSLSvFq8KMZI3LTtG4bF+4DT6Ny+7d+rzuCP9pAQCIDIIaACBhuF1GeenJyktP1iyduLqlZAuiNLT7bZALh7i6VttLV9vSparmDh1q6NCWQ42qael6z+NzUr3h4NYb5vqGukyfV+k+D1UuAQAxR1ADAIxYHrerZy6bik58boc/qMomG9wON7brUEO7DjV26HB4gfG39taqqSNwzMemJrmVnuxRhs+jdJ9XGd37yR6l+zzh2zbUpfe5LyXJrdQkj9KS3D37hD4AwGAQ1AAApwSf163JeWmanHfsoiiS1NIZ0JHGdh1ssEsUNHcE1NzhV0tHQC2dATV3BtTcEVBLh19VzR3h/YBaugLHHYJ5tGSPS2nJHqV43UpLdislHORSk3r305I9Gpvls8ss5NilFnLTkmQMIQ8AThUENQAAwtKTPZpemKHphQMPu+wrFHLU5g/2hLqmjoDaugJq6wr2bjuD/W63dgXU3hVUa1dQ7V0BHWrwq90fVGs4DLb7g/2u4fO6wsEtVePD69j1DXJFLEQOAKMKQQ0AgGFyuYwdBpnskbKG/3yO46ipPaCKhjYdrG/XwYZ2Hay3a9kdbGjX5oONqmvtP+fO4zIaE+6FK8hI7mlPWt+tz6P0ZLfSknqPdw/XTPa46LEDgARCUAMAIMEYY5SV6lVWapZKxh07+bV1BXoWIu8Oct3bLYea1NIZUGun7b0bDLfLKC3JrQyfV5kpXuWkepWTmqSs1D77KXabnepVdmqSclK9ykrxshQCAEQBQQ0AgBEoNWlwwzSDIUdtXYGe4NbSGQxv7fy61j73tXYG1dwRUGN7lxra/Np+pEkNbX41tP
sVPNYCd2EZPo+y+4S5zBQb4DJ94W2K56jbdpvh88hLyAOAYyKoAQAwirldRhk+rzJ83pN+Dsdx1NwZUEOrXw3tXapv86uhzYa5+vC2oc2ua1ff5tfB+nY1dfjV2O6XP3jiKitpSe5+wc5W1uxfUTO9T1XNjOSj7/fK52XYJoDRh6AGAABOyBijTJ8NUpOUOujHOY6jDn+oJ7Q1tYe3HX41tvnV1BHod7yx3a/K5g7tqra9fc2dAXUFQgNexx2eI5iZ4lFeWrLy05OUl5as3PQk5aUlKT89WblpScpLt/s5qUlK8tCTByCxEdQAAEBUGGOUEl5DrijTd1LP0RkIhodk+u1yCOEhm73LJfQun9DQZhc/P9jQoY0VtuBK4DhDNjN9HuWnJysvPSkc4pJtb16SR6nJ4aIrfYuxJHmUFj5G8RUAsUBQAwAACSvZ41ayx63ctKQhPzYUctTU4Vdta5dqW7pU29LZu9/avd+pPdWtWrOvXi0dAXUFB+7Bk3qLr6Qn22CXluRWstctn9ctn8dlt97urT3Wc7/XJZ+ndz/Fa4u4ZKV4lZVqF1R3sdQCcMojqAEAgFHJ5TLKTk1SdmqSTisY3GO6AqE+xVfsenetfQqxtB1VfKW1MxA+J6gOf1CN7X5V+e1+hz+kjkDv/qDbbdQzby87XHwlOzVJWSkeZack9QS67vuzUr298/iSPaynB4wSBDUAAICwJI9LSR4b7iLJcRx1BkK9Ac4fDIe4kNq7guGhm109c/Ua2/1qaAtv2/2qqG/vuf8EBTglqWeZhXSfxw7nTPYo09cd5sJFWHy25y7D5+kJhd0VOenRAxIDQQ0AACDKjDE9wyCHIxRy1NIVUGObv1+oa+kIqKnDb+fudXQXY7Hz+po77Jp73fcNtLaey6h3KGa/ENc/1NklFrzyuow8bpc8biOvy2493cdcRl63S26Xkdfde8zjMnK7DPP8gBMgqAEAAIwQLldvBc6JJ/kcgWDIFmgJB7m+ga/pqCqc3T+HG9vV2B5QU7t/0PP4BmKMlOyxc/RSvO6ewjN236MUr0upSR75wvenhu/3hfdTw3MEM7qXdejuOfQx/BOjA0ENAADgFOJxu5SV6lJW6tDX1utecqE7wDV32LXygiFH/lBIgaCjQDCkQMhRIBSSP+goEHQU7N4Pb4Mhe15nIKS2rqDa/UG199k2tvtV2RhUmz+g9i47VLStKzDgsM9uRw//7A5zmb7euXwpXndPT5873BPYve8N9/h19wba+8I9g+Fzkz29xWK69wmIiCSCGgAAAAal75ILY7JObsmFk+U4jrqCIXV0hdTmtwVcWvos0dAcHv7ZPdyzpc/Qz8a2LlXUtfUs6TCU4i5D4XUb+TxuJXtdSvb0Vv7sG+p8XpfSkjzKTvWGi914lZ1it1kpXuWkJSk7xavUJDdDQ09xBDUAAAAkPGNMz3INWRp6b2BfXYGQOgNB2xPY3SMYDNmevj69ft3HjtVr2FscJqiOQEidR1X57AwE7bFw4ZjO8OLvHf6QWjoCamjvOmFg9LqNslKSlJPqDYe43jCXEg5/yX2CoP3duHpCYs+x8HIQ9njvfRSMSXwENQAAAJxSbHVPV7yboQ5/UA1tfjW0d9ltm1+N7V2q77Pffbyivk1bDtkhp+3+oJxBDgM9nu5evpSj1/zz9j3e/76Uo87vmVPodcvXZ797LmGK1y2vm6IxJ4ugBgAAAMSBz+vWmKyhDyN1HEeBkF3yodMftNtwD9/Rx7p79rr327t7/MK9ge19loxo99tz61q71N7Vu4RER3jfHxx6OnS7TE/osz15ksflksvY+1zGzgfs+TFGrvDW7erel9wul9KT3T3zDXu3tojM0cfSkkb+MhMENQAAAGAEMcYWPvG6XUpPjt3H+UAwpI6AXfuvO9j1FIHxB9XRZ7//OSG1+4PqCoQUCofMUMgOJw064X0nfDv8EwiF1BlwFHTsshT+YEitXYGeeYfBASrLGKOeSqAZPo
/GZafo5zctjtFvKjIIagAAAAAG5HG7lB7jcHgsjuOo3R8Mhza/mjoCPfv9t70FZpITYKjrUBHUAAAAAIwYxhilJnmUmuRRUWZsq4/G0siLlgAAAAAwyhHUAAAAACDBENQAAAAAIMEQ1AAAAAAgwRDUAAAAACDBENQAAAAAIMEQ1AAAAAAgwRDUAAAAACDBENQAAAAAIMEQ1AAAAAAgwRDUAAAAACDBENQAAAAAIMEQ1AAAAAAgwRjHceJzYWOqJe2Py8VPLF9STbwbgVMGrzfECq81xAqvNcQKrzXEUrReb5Mdxyk41h1xC2qJyhiz1nGcsni3A6cGXm+IFV5riBVea4gVXmuIpXi83hj6CAAAAAAJhqAGAAAAAAmGoPZeD8a7ATil8HpDrPBaQ6zwWkOs8FpDLMX89cYcNQAAAABIMPSoAQAAAECCIaj1YYy52BizwxizyxhzR7zbg9HDGPNzY0yVMWZzn2O5xpi/GGN2hrc58WwjRgdjzERjzCpjzFZjzBZjzL+Ej/N6Q8QZY3zGmLeMMRvCr7dvhI9PNcasDr+f/p8xJinebcXoYIxxG2PeMcb8KXyb1xoizhizzxizyRiz3hizNnws5u+jBLUwY4xb0v2SLpE0R9JHjDFz4tsqjCK/kHTxUcfukPSi4zgzJL0Yvg0MV0DSFxzHmSNpmaRbw/+X8XpDNHRKOs9xnPmSFki62BizTNJ3JH3fcZzpkuolfSJ+TcQo8y+StvW5zWsN0XKu4zgL+pTkj/n7KEGt1xJJuxzH2eM4TpekRyVdGec2YZRwHOdVSXVHHb5S0sPh/YclfSCWbcLo5DjOYcdx3g7vN8t+oBkvXm+IAsdqCd/0hn8cSedJejx8nNcbIsIYM0HSZZJ+Fr5txGsNsRPz91GCWq/xksr73K4IHwOipchxnMPh/SOSiuLZGIw+xpgpkhZKWi1eb4iS8FC09ZKqJP1F0m5JDY7jBMKn8H6KSLlX0r9JCoVv54nXGqLDkfRnY8w6Y8zN4WMxfx/1RPsCAAbmOI5jjKEEKyLGGJMu6QlJn3Mcp8l+8WzxekMkOY4TlLTAGJMt6UlJs+PbIoxGxpj3S6pyHGedMWZlnJuD0e8sx3EOGmMKJf3FGLO9752xeh+lR63XQUkT+9yeED4GREulMWasJIW3VXFuD0YJY4xXNqT9xnGc34cP83pDVDmO0yBplaQzJGUbY7q/DOb9FJGwXNIVxph9stNTzpP0A/FaQxQ4jnMwvK2S/QJqieLwPkpQ67VG0oxw9aAkSddJeirObcLo9pSkG8P7N0r6YxzbglEiPGfjfyVtcxzne33u4vWGiDPGFIR70mSMSZH0Ptl5kaskfSh8Gq83DJvjOP/uOM4Ex3GmyH5Ge8lxnOvFaw0RZoxJM8ZkdO9LulDSZsXhfZQFr/swxlwqO/7ZLennjuPcFd8WYbQwxjwiaaWkfEmVku6U9AdJj0maJGm/pGscxzm64AgwJMaYsyS9JmmTeudxfFl2nhqvN0SUMWae7KR6t+yXv485jvNNY8w02V6PXEnvSLrBcZzO+LUUo0l46OMXHcd5P681RFr4NfVk+KZH0m8dx7nLGJOnGL+PEtQAAAAAIMEw9BEAAAAAEgxBDQAAAAASDEENAAAAABIMQQ0AAAAAEgxBDQAAAAASDEENAAAAABIMQQ0AAAAAEgxBDQAAAAASzP8HsEQNQxNyBGwAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 1080x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(15,6))\n",
"plt.plot(train_losses, label='train loss')\n",
"plt.plot(val_losses, label='validation loss')\n",
"plt.legend()"
],
"id": "I5ZlWAm3Bpjd"
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdnbiBHyBpjd"
},
"source": [
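"# Model Testing\n",
"\n",
"Reload the saved checkpoint (`best_val_model_pytorch`) and evaluate it on the test set (`test_loader`): overall accuracy first, then sklearn's per-class precision / recall / F1 report."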
],
"id": "hdnbiBHyBpjd"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o5Q1MM2FBpjd",
"outputId": "75d7c927-dba0-44cd-f206-1c479567bf4a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Test acc: 0.6099\n"
]
}
],
"source": [
"# Reload the saved checkpoint. map_location loads its tensors onto `device`, so a\n",
"# checkpoint saved on a GPU can still be opened when `device` is the CPU.\n",
"model = torch.load('best_val_model_pytorch', map_location=device)\n",
"# Set inference mode explicitly instead of relying on the train/eval flag that was\n",
"# pickled with the module (it changes the behaviour of layers such as BatchNorm/Dropout).\n",
"model.eval()\n",
"\n",
"n_correct = 0.\n",
"n_total = 0.\n",
"# Gradients are not needed for evaluation; skipping autograd saves memory and time.\n",
"with torch.no_grad():\n",
"    for inputs, targets in test_loader:\n",
"        # Move to the compute device\n",
"        inputs, targets = inputs.to(device, dtype=torch.float), targets.to(device, dtype=torch.int64)\n",
"\n",
"        # Forward pass\n",
"        outputs = model(inputs)\n",
"\n",
"        # torch.max over dim 1 returns (values, indices); the indices are the predicted classes\n",
"        _, predictions = torch.max(outputs, 1)\n",
"\n",
"        # update counts\n",
"        n_correct += (predictions == targets).sum().item()\n",
"        n_total += targets.shape[0]\n",
"\n",
"test_acc = n_correct / n_total\n",
"print(f\"Test acc: {test_acc:.4f}\")"
],
"id": "o5Q1MM2FBpjd"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z15vTBfMBpjd"
},
"outputs": [],
"source": [
"# Collect per-sample targets and predictions for the sklearn report below.\n",
"# Uses the `model` loaded in the previous cell; eval mode is set again here so the\n",
"# predictions do not depend on the train/eval state left behind by earlier cells.\n",
"model.eval()\n",
"all_targets = []\n",
"all_predictions = []\n",
"\n",
"# Inference only: no autograd graphs are needed.\n",
"with torch.no_grad():\n",
"    for inputs, targets in test_loader:\n",
"        # Move to the compute device\n",
"        inputs, targets = inputs.to(device, dtype=torch.float), targets.to(device, dtype=torch.int64)\n",
"\n",
"        # Forward pass; torch.max over dim 1 returns (values, indices) and the\n",
"        # indices are the predicted classes\n",
"        outputs = model(inputs)\n",
"        _, predictions = torch.max(outputs, 1)\n",
"\n",
"        all_targets.append(targets.cpu().numpy())\n",
"        all_predictions.append(predictions.cpu().numpy())\n",
"\n",
"all_targets = np.concatenate(all_targets)\n",
"all_predictions = np.concatenate(all_predictions)"
],
"id": "Z15vTBfMBpjd"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0csU-PiNF2gP",
"outputId": "952d1b3d-3d2b-418c-f09b-69ab4b81ed2b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"accuracy_score: 0.6098732507455839\n",
" precision recall f1-score support\n",
"\n",
" 0 0.7163 0.4758 0.5718 47915\n",
" 1 0.7589 0.6107 0.6768 48050\n",
" 2 0.4773 0.7566 0.5853 43523\n",
"\n",
" accuracy 0.6099 139488\n",
" macro avg 0.6508 0.6144 0.6113 139488\n",
"weighted avg 0.6564 0.6099 0.6122 139488\n",
"\n"
]
}
],
"source": [
"# Overall accuracy, then sklearn's per-class precision / recall / F1 table (4 decimals).\n",
"print('accuracy_score:', accuracy_score(y_true=all_targets, y_pred=all_predictions))\n",
"print(classification_report(y_true=all_targets, y_pred=all_predictions, digits=4))"
],
"id": "0csU-PiNF2gP"
},
{
"cell_type": "markdown",
"metadata": {
"id": "dU8zeoyLBpje"
},
"source": [
"**Reference result from a different run.** A saved output from another run of this notebook reported test accuracy 0.7535 (macro-avg F1 0.7533, weighted-avg F1 0.7540) on the same 139,488 test samples. The evaluation cells above show 0.6099 for this run, so that figure is recorded here as text rather than as the stale output of a duplicate report cell."
],
"id": "dU8zeoyLBpje"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
},
"colab": {
"name": "DeepLOB.ipynb",
"provenance": [],
"machine_shape": "hm",
"include_colab_link": true
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment