Created
April 8, 2021 16:44
Preparing data with tsai
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-07T11:26:25.200592Z", | |
"end_time": "2021-04-07T11:26:39.934457Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "from tsai.all import *", | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "# Preparing data with tsai" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "As you know, `tsai` requires input data to be numpy arrays with the following shape: \n\n[samples x vars x steps]\n\n* samples = # samples (individual series or subsequences)\n\n* vars = # variables, features or channels\n\n* steps = # steps in the sequence, sequence length, or timesteps\n\nBut sometimes all you have is a dataframe containing univariate or multivariate data that you need to convert to the required format. Let's see how `tsai` can help you do that." | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Applying a sliding window to a dataframe" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "A sliding window lets you create X (and y) numpy arrays containing multiple subsequences from one long sequence stored in a pandas dataframe. The long sequence can be univariate or multivariate." | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T15:26:41.027126Z", | |
"end_time": "2021-04-08T15:26:41.058274Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "x = np.arange(0, 2_000, 2).reshape(-1,1)\ny = np.arange(1, 2_000, 2).reshape(-1,1)\ntarget = np.random.randint(0, 10, 1_000).reshape(-1,1)\ndata = np.concatenate([x,y, target], -1)\ndf = pd.DataFrame(data, columns=['x', 'y', 'target'])\ndf.head()", | |
"execution_count": 169, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 169, | |
"data": { | |
"text/plain": " x y target\n0 0 1 1\n1 2 3 1\n2 4 5 0\n3 6 7 8\n4 8 9 3" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:36:22.420180Z", | |
"end_time": "2021-04-08T16:36:22.427345Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def y_func(o): \n return scipy.stats.mode(o, axis=1).mode.ravel()", | |
"execution_count": 211, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:36:23.186758Z", | |
"end_time": "2021-04-08T16:36:23.253660Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "window_len = 10\nstride = 5 # windows overlap because stride < window_len\nstart = 0\nget_x = ['x', 'y'] # feature columns\nget_y = ['target'] # target column\nhorizon = -window_len # take the targets from the same steps as each window\nseq_first = True # rows of df are time steps\nsort_by = None\nascending = True\ncheck_leakage = True\n\n# y_func (defined above) reduces each window's targets to a single label\nsw = SlidingWindow(window_len, stride=stride, get_x=get_x, get_y=get_y, y_func=y_func, horizon=horizon, seq_first=seq_first)\n\nX, y = sw(df)\nX.shape, y.shape", | |
"execution_count": 212, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 212, | |
"data": { | |
"text/plain": "((199, 2, 10), (199,))" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "When you use `SlidingWindow`, two numpy arrays (X and y) containing all the windowed data are created up front." | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:36:24.131953Z", | |
"end_time": "2021-04-08T16:36:24.162723Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "df[:15]", | |
"execution_count": 213, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 213, | |
"data": { | |
"text/plain": " x y z target\n0 0 1 1 9\n1 2 3 3 1\n2 4 5 5 3\n3 6 7 7 1\n4 8 9 9 5\n5 10 11 11 3\n6 12 13 13 9\n7 14 15 15 3\n8 16 17 17 4\n9 18 19 19 9\n10 20 21 21 1\n11 22 23 23 6\n12 24 25 25 4\n13 26 27 27 2\n14 28 29 29 8" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:36:25.761950Z", | |
"end_time": "2021-04-08T16:36:25.774704Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "X[0], y[0]", | |
"execution_count": 214, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 214, | |
"data": { | |
"text/plain": "(array([[ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18],\n [ 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]]),\n 3)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:36:27.457299Z", | |
"end_time": "2021-04-08T16:36:27.469003Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "X[1], y[1]", | |
"execution_count": 215, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 215, | |
"data": { | |
"text/plain": "(array([[10, 12, 14, 16, 18, 20, 22, 24, 26, 28],\n [11, 13, 15, 17, 19, 21, 23, 25, 27, 29]]),\n 3)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T15:24:06.157485Z", | |
"end_time": "2021-04-08T15:24:06.839321Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "splits = TimeSplitter(valid_size=.2)(y)\nsplits", | |
"execution_count": 165, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<Figure size 1152x36 with 1 Axes>" | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 165, | |
"data": { | |
"text/plain": "((#159) [0,1,2,3,4,5,6,7,8,9...],\n (#40) [159,160,161,162,163,164,165,166,167,168...])" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T15:25:16.634742Z", | |
"end_time": "2021-04-08T15:25:16.929845Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "tfms = [None, TSClassification()]\ndls = get_ts_dls(X, y, splits=splits, tfms=tfms)\nxb,yb = dls.train.one_batch()\nxb, yb", | |
"execution_count": 167, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 167, | |
"data": { | |
"text/plain": "(TSTensor(samples:64, vars:2, len:10),\n TensorCategory([7, 0, 1, 2, 0, 1, 9, 5, 9, 8, 2, 0, 6, 1, 7, 0, 6, 3, 6, 6, 7, 4, 4, 2,\n 4, 6, 5, 1, 1, 0, 0, 2, 1, 2, 8, 3, 5, 0, 1, 0, 8, 2, 2, 0, 1, 9, 1, 2,\n 4, 0, 6, 3, 3, 1, 5, 1, 0, 2, 1, 1, 2, 2, 1, 2]))" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Creating TSUnwindowedDatasets" | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "Sometimes your dataframe may be too big, and using SlidingWindow would create an even bigger array (as there may be some overlap between windows). In these cases, you may want to create batches \"on the fly\". \nTo achieve that, we have included TSUnwindowedDataset and TSUnwindowedDatasets. This is how you can use them: " | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:37:22.002909Z", | |
"end_time": "2021-04-08T16:37:22.049274Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "x = np.arange(0, 2_000, 2).reshape(-1,1)\ny = np.arange(1, 2_000, 2).reshape(-1,1)\nz = np.arange(1, 2_000, 2).reshape(-1,1)\ntarget = np.random.randint(0, 10, 1_000).reshape(-1,1)\ndata = np.concatenate([x, y, z, target], -1)\ndf = pd.DataFrame(data, columns=['x', 'y', 'z', 'target'])\ndf.head()\n\nX, y = df2xy(df, data_cols=['x', 'y'], target_col='target', to3d=False)\nX.shape, y.shape", | |
"execution_count": 220, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 220, | |
"data": { | |
"text/plain": "((1000, 2), (1000,))" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "Initially we don't know the number of samples the windowed dataset will contain, so we cannot pass splits yet." | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:37:38.833527Z", | |
"end_time": "2021-04-08T16:37:38.841648Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def y_func(o): \n return np.max(o, axis=1).ravel()", | |
"execution_count": 223, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:37:39.543525Z", | |
"end_time": "2021-04-08T16:37:39.560229Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "window_size = 10\nstride = 5\ndrop_start = 0\ndrop_end = 0\nseq_first = True\ndset = TSUnwindowedDataset(X, y, y_func=y_func, window_size=window_size, stride=stride, drop_start=drop_start, drop_end=drop_end, seq_first=seq_first)\nlen(dset)", | |
"execution_count": 224, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 224, | |
"data": { | |
"text/plain": "199" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:39:52.676637Z", | |
"end_time": "2021-04-08T16:39:52.699860Z" | |
} | |
}, | |
"cell_type": "markdown", | |
"source": "When using TSUnwindowedDataset, the windowed data is created on the fly. That means we don't need to store the windowed data (which is usually larger than the original).\n\nThe downside to this approach is that batch creation will be somewhat slower. In any case, you can test both approaches and see what works best for you." | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "Now we can create the splits and a TSUnwindowedDatasets object:" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:37:41.943958Z", | |
"end_time": "2021-04-08T16:37:42.339399Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "splits = TimeSplitter(valid_size=.2)(range_of(len(dset)))\nsplits", | |
"execution_count": 225, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<Figure size 1152x36 with 1 Axes>" | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 225, | |
"data": { | |
"text/plain": "((#159) [0,1,2,3,4,5,6,7,8,9...],\n (#40) [159,160,161,162,163,164,165,166,167,168...])" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:37:42.951122Z", | |
"end_time": "2021-04-08T16:37:42.966333Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "dsets = TSUnwindowedDatasets(dset, splits=splits)\nlen(dsets.train), len(dsets.valid)", | |
"execution_count": 226, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 226, | |
"data": { | |
"text/plain": "(159, 40)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-04-08T16:37:44.095374Z", | |
"end_time": "2021-04-08T16:37:44.116917Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "dls = TSDataLoaders.from_dsets(dsets.train, dsets.valid)\nxb, yb = dls.train.one_batch()\nxb, yb", | |
"execution_count": 227, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 227, | |
"data": { | |
"text/plain": "(TSTensor(samples:64, vars:2, len:10), TSLabelTensor(shape:(64,)))" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3", | |
"language": "python" | |
}, | |
"toc": { | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"base_numbering": 1, | |
"title_cell": "Table of Contents", | |
"title_sidebar": "Contents", | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": true, | |
"toc_window_display": false | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.7.6", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "Preparing data with tsai", | |
"public": true | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
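For readers who want to see what the sliding-window step in the notebook above does without installing `tsai`, here is a minimal plain-Python sketch. It is an illustration under stated assumptions, not tsai's implementation: `sliding_windows` and `most_common` are hypothetical helper names, nested lists stand in for numpy arrays, and the target is a deterministic stand-in for the notebook's random one.

```python
from collections import Counter


def sliding_windows(rows, targets, window_len, stride, y_func):
    """Split one long multivariate sequence into (possibly overlapping) windows.

    rows:    one tuple of feature values per time step, e.g. [(x0, y0), (x1, y1), ...]
    targets: one target value per time step
    Returns (X, y), where X is nested as [samples][vars][steps] and
    y holds one label per window.
    """
    X, y = [], []
    # Only full windows are kept: the last start is len(rows) - window_len.
    for start in range(0, len(rows) - window_len + 1, stride):
        window = rows[start:start + window_len]
        # Transpose the window from steps x vars to vars x steps.
        X.append([list(var) for var in zip(*window)])
        # Reduce the targets of the same steps to a single label.
        y.append(y_func(targets[start:start + window_len]))
    return X, y


def most_common(values):
    """Most frequent value in a window (ties go to the value seen first)."""
    return Counter(values).most_common(1)[0][0]


# Same toy series as the notebook: x holds even numbers, y odd numbers, 1,000 steps.
rows = [(2 * i, 2 * i + 1) for i in range(1000)]
# Assumption: a deterministic target instead of np.random.randint.
targets = [i % 10 for i in range(1000)]

X, y = sliding_windows(rows, targets, window_len=10, stride=5, y_func=most_common)
shape = (len(X), len(X[0]), len(X[0][0]))
print(shape)    # (199, 2, 10)
print(X[1][0])  # [10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
```

With 1,000 time steps, a window of 10 and a stride of 5, the window starts are 0, 5, ..., 990, which gives (1000 - 10) // 5 + 1 = 199 samples, the same count the notebook reports. Note that `most_common` breaks ties by first occurrence, whereas `scipy.stats.mode` returns the smallest of the tied values, so labels can differ when several targets are equally frequent in a window.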