Last active
January 22, 2023 03:00
-
-
Save reachsumit/c1f7e11c0cfa5f696fd9ccd392f9b8d0 to your computer and use it in GitHub Desktop.
DeepAR PyTorch end-to-end demo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.7.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import os\nimport pandas as pd\nimport numpy as np\nfrom tqdm import trange, tqdm\n\nfrom io import BytesIO\nfrom urllib.request import urlopen\nfrom zipfile import ZipFile\n\nfrom pandas import read_csv, DataFrame\nfrom scipy import stats\n\ndef prep_data(data, covariates, data_start, train = True):\n time_len = data.shape[0]\n input_size = window_size-stride_size\n windows_per_series = np.full((num_series), (time_len-input_size) // stride_size)\n if train: windows_per_series -= (data_start+stride_size-1) // stride_size\n total_windows = np.sum(windows_per_series)\n x_input = np.zeros((total_windows, window_size, 1 + num_covariates + 1), dtype='float32')\n label = np.zeros((total_windows, window_size), dtype='float32')\n v_input = np.zeros((total_windows, 2), dtype='float32')\n count = 0\n if not train:\n covariates = covariates[-time_len:]\n for series in trange(num_series):\n cov_age = stats.zscore(np.arange(total_time-data_start[series]))\n if train:\n covariates[data_start[series]:time_len, 0] = cov_age[:time_len-data_start[series]]\n else:\n covariates[:, 0] = cov_age[-time_len:]\n for i in range(windows_per_series[series]):\n if train:\n window_start = stride_size*i+data_start[series]\n else:\n window_start = stride_size*i\n window_end = window_start+window_size\n x_input[count, 1:, 0] = data[window_start:window_end-1, series]\n x_input[count, :, 1:1+num_covariates] = covariates[window_start:window_end, :]\n x_input[count, :, -1] = series\n label[count, :] = data[window_start:window_end, series]\n nonzero_sum = (x_input[count, 1:input_size, 0]!=0).sum()\n if nonzero_sum == 0:\n v_input[count, 0] = 0\n else:\n v_input[count, 0] = np.true_divide(x_input[count, 1:input_size, 0].sum(),nonzero_sum)+1\n x_input[count, :, 0] = x_input[count, :, 0]/v_input[count, 0]\n if train:\n label[count, :] = label[count, :]/v_input[count, 0]\n count += 1\n return x_input, v_input, label\n\ndef gen_covariates(times, num_covariates):\n covariates = np.zeros((times.shape[0], num_covariates))\n for i, input_time in enumerate(times):\n covariates[i, 1] = input_time.weekday()\n covariates[i, 2] = input_time.hour\n covariates[i, 3] = input_time.month\n for i in range(1,num_covariates):\n covariates[:,i] = stats.zscore(covariates[:,i])\n return covariates[:, :num_covariates]","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","execution":{"iopub.status.busy":"2023-01-09T03:41:15.751483Z","iopub.execute_input":"2023-01-09T03:41:15.751814Z","iopub.status.idle":"2023-01-09T03:41:16.498277Z","shell.execute_reply.started":"2023-01-09T03:41:15.751733Z","shell.execute_reply":"2023-01-09T03:41:16.497172Z"},"trusted":true},"execution_count":1,"outputs":[]},{"cell_type":"code","source":"name = 'LD2011_2014.txt'\nsave_name = 'elect'\nwindow_size = 192\nstride_size = 24\nnum_covariates = 4\ntrain_start = '2011-01-01 00:00:00'\ntrain_end = '2014-08-31 23:00:00'\ntest_start = '2014-08-25 00:00:00' #need additional 7 days as given info\ntest_end = '2014-09-07 23:00:00'\n\nsave_path = os.path.join('data', save_name)\nif not os.path.exists(save_path):\n os.makedirs(save_path)\ncsv_path = os.path.join(save_path, name)\nif not os.path.exists(csv_path):\n zipurl = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip'\n with urlopen(zipurl) as zipresp:\n with ZipFile(BytesIO(zipresp.read())) as zfile:\n zfile.extractall(save_path)\n\ndata_frame = pd.read_csv(csv_path, sep=\";\", index_col=0, parse_dates=True, decimal=',')\ndata_frame = data_frame.resample('1H',label = 'left',closed = 'right').sum()[train_start:test_end]\ndata_frame.fillna(0, inplace=True)\ncovariates = gen_covariates(data_frame[train_start:test_end].index, num_covariates)\ntrain_data = data_frame[train_start:train_end].values\ntest_data = data_frame[test_start:test_end].values\ndata_start = (train_data!=0).argmax(axis=0) #find first nonzero value in each time series\ntotal_time = data_frame.shape[0] #32304\nnum_series = data_frame.shape[1] #370\nX_train, v_train, y_train = prep_data(train_data, covariates, data_start)\nX_test, v_test, y_test = prep_data(test_data, covariates, data_start, train=False)","metadata":{"execution":{"iopub.status.busy":"2023-01-09T03:41:16.504848Z","iopub.execute_input":"2023-01-09T03:41:16.507407Z","iopub.status.idle":"2023-01-09T03:41:43.930075Z","shell.execute_reply.started":"2023-01-09T03:41:16.507359Z","shell.execute_reply":"2023-01-09T03:41:43.929045Z"},"trusted":true},"execution_count":2,"outputs":[{"name":"stderr","text":"100%|██████████| 370/370 [00:09<00:00, 38.24it/s]\n100%|██████████| 370/370 [00:00<00:00, 2027.56it/s]\n","output_type":"stream"}]},{"cell_type":"code","source":"import math\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable","metadata":{"execution":{"iopub.status.busy":"2023-01-09T03:41:43.931515Z","iopub.execute_input":"2023-01-09T03:41:43.932910Z","iopub.status.idle":"2023-01-09T03:41:45.661285Z","shell.execute_reply.started":"2023-01-09T03:41:43.932863Z","shell.execute_reply":"2023-01-09T03:41:45.660322Z"},"trusted":true},"execution_count":3,"outputs":[]},{"cell_type":"code","source":"class DeepAR(nn.Module):\n def __init__(self,\n num_class=num_series,\n embedding_dim=20,\n cov_dim=num_covariates,\n lstm_hidden_dim=40,\n lstm_layers=3,\n lstm_dropout=0.1,\n sample_times=200,\n predict_start=window_size-stride_size,\n predict_steps=stride_size,\n device=torch.device('cuda')):\n super(DeepAR, self).__init__()\n self.lstm_layers = lstm_layers\n self.lstm_hidden_dim = lstm_hidden_dim\n self.device = device\n self.sample_times = sample_times\n self.predict_steps = predict_steps\n self.predict_start = predict_start\n self.embedding = nn.Embedding(num_class, embedding_dim)\n\n self.lstm = nn.LSTM(input_size=1+cov_dim+embedding_dim,\n hidden_size=lstm_hidden_dim,\n num_layers=lstm_layers,\n bias=True,\n batch_first=False,\n dropout=lstm_dropout)\n\n self.relu = nn.ReLU()\n self.distribution_mu = nn.Linear(lstm_hidden_dim * lstm_layers, 1)\n self.distribution_presigma = nn.Linear(lstm_hidden_dim * lstm_layers, 1)\n self.distribution_sigma = nn.Softplus()\n\n def forward(self, x, idx, hidden, cell):\n onehot_embed = self.embedding(idx) # use an embedding corresponding to time series idx\n lstm_input = torch.cat((x, onehot_embed), dim=2) # concat embedding with the training data\n output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell)) # process through LSTM, batch_first is False by default\n \n # use h from all three layers to calculate mu and sigma\n hidden_permute = hidden.permute(1, 2, 0).contiguous().view(hidden.shape[1], -1)\n \n mu = self.distribution_mu(hidden_permute)\n pre_sigma = self.distribution_presigma(hidden_permute)\n sigma = self.distribution_sigma(pre_sigma) # softplus to make sure standard deviation is positive\n return torch.squeeze(mu), torch.squeeze(sigma), hidden, cell\n\n def init_hidden(self, input_size):\n return torch.zeros(self.lstm_layers, input_size, self.lstm_hidden_dim, device=self.device)\n\n def init_cell(self, input_size):\n return torch.zeros(self.lstm_layers, input_size, self.lstm_hidden_dim, device=self.device)\n \n def test(self, x, v_batch, id_batch, hidden, cell):\n batch_size = x.shape[1]\n samples = torch.zeros(self.sample_times, batch_size, self.predict_steps,\n device=self.device)\n for j in range(self.sample_times):\n decoder_hidden = hidden\n decoder_cell = cell\n for t in range(self.predict_steps):\n mu_de, sigma_de, decoder_hidden, decoder_cell = self(x[self.predict_start + t].unsqueeze(0),\n id_batch, decoder_hidden, decoder_cell)\n gaussian = torch.distributions.normal.Normal(mu_de, sigma_de)\n pred = gaussian.sample() # not scaled\n samples[j, :, t] = pred * v_batch[:, 0] + v_batch[:, 1]\n if t < (self.predict_steps - 1):\n x[self.predict_start + t + 1, :, 0] = pred\n\n sample_mu = torch.median(samples, dim=0)[0]\n sample_sigma = samples.std(dim=0)\n return samples, sample_mu, sample_sigma\n\ndef loss_fn(mu: Variable, sigma: Variable, labels: Variable):\n zero_index = (labels != 0)\n distribution = torch.distributions.normal.Normal(mu[zero_index], sigma[zero_index])\n likelihood = distribution.log_prob(labels[zero_index])\n return -torch.mean(likelihood)","metadata":{"execution":{"iopub.status.busy":"2023-01-09T03:41:45.663859Z","iopub.execute_input":"2023-01-09T03:41:45.664401Z","iopub.status.idle":"2023-01-09T03:41:45.683560Z","shell.execute_reply.started":"2023-01-09T03:41:45.664364Z","shell.execute_reply":"2023-01-09T03:41:45.682673Z"},"trusted":true},"execution_count":4,"outputs":[]},{"cell_type":"code","source":"from torch.utils.data import DataLoader, Dataset, Sampler\nfrom torch.utils.data.sampler import RandomSampler\n\nclass TrainDataset(Dataset):\n def __init__(self, data, label):\n self.data = data\n self.label = label\n self.train_len = self.data.shape[0]\n\n def __len__(self):\n return self.train_len\n\n def __getitem__(self, index):\n return (self.data[index,:,:-1],int(self.data[index,0,-1]), self.label[index])\n\nclass TestDataset(Dataset):\n def __init__(self, data, v, label):\n self.data = data\n self.v = v\n self.label = label\n self.test_len = self.data.shape[0]\n\n def __len__(self):\n return self.test_len\n\n def __getitem__(self, index):\n return (self.data[index,:,:-1],int(self.data[index,0,-1]),self.v[index],self.label[index])\n\nclass WeightedSampler(Sampler):\n def __init__(self, v, replacement=True):\n self.weights = torch.as_tensor(np.abs(v[:,0])/np.sum(np.abs(v[:,0])), dtype=torch.double)\n self.num_samples = self.weights.shape[0]\n self.replacement = replacement\n\n def __iter__(self):\n return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())\n\n def __len__(self):\n return self.num_samples","metadata":{"execution":{"iopub.status.busy":"2023-01-09T03:41:45.685075Z","iopub.execute_input":"2023-01-09T03:41:45.685441Z","iopub.status.idle":"2023-01-09T03:41:45.700323Z","shell.execute_reply.started":"2023-01-09T03:41:45.685406Z","shell.execute_reply":"2023-01-09T03:41:45.699380Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"code","source":"train_set = TrainDataset(data=X_train, label=y_train)\ntest_set = TestDataset(data=X_test, v=v_test, label=y_test)\nsampler = WeightedSampler(v=v_train) # Use weighted sampler instead of random sampler\ntrain_loader = DataLoader(train_set, batch_size=64, sampler=sampler)\ntest_loader = DataLoader(test_set, batch_size=256, sampler=RandomSampler(test_set))","metadata":{"execution":{"iopub.status.busy":"2023-01-09T03:41:45.701746Z","iopub.execute_input":"2023-01-09T03:41:45.702092Z","iopub.status.idle":"2023-01-09T03:41:45.717334Z","shell.execute_reply.started":"2023-01-09T03:41:45.702059Z","shell.execute_reply":"2023-01-09T03:41:45.716412Z"},"trusted":true},"execution_count":6,"outputs":[]},{"cell_type":"code","source":"import torch.optim as optim\n\ndef accuracy_RMSE(mu: torch.Tensor, labels: torch.Tensor, relative = False):\n zero_index = (labels != 0)\n diff = torch.sum(torch.mul((mu[zero_index] - labels[zero_index]), (mu[zero_index] - labels[zero_index]))).item()\n if relative is False:\n return [diff, torch.sum(zero_index).item(), torch.sum(zero_index).item()]\n else:\n summation = torch.sum(torch.abs(labels[zero_index])).item()\n if summation == 0:\n logger.error('summation denominator error! ')\n return [diff, summation, torch.sum(zero_index).item()]\n\n\ndef update_metrics(raw_metrics, input_mu, input_sigma, sample_mu, labels, predict_start, samples=None, relative=False):\n # TODO: use samples to calcualte rou50, rou90 metrics\n raw_metrics['RMSE'] = raw_metrics['RMSE'] + accuracy_RMSE(sample_mu, labels[:, predict_start:], relative=relative)\n input_time_steps = input_mu.numel()\n raw_metrics['test_loss'] = raw_metrics['test_loss'] + [\n loss_fn(input_mu, input_sigma, labels[:, :predict_start]).cpu() * input_time_steps, input_time_steps]\n return raw_metrics\n\n\ndef final_metrics(raw_metrics):\n summary_metric = {}\n summary_metric['RMSE'] = np.sqrt(raw_metrics['RMSE'][0] / raw_metrics['RMSE'][2]) / (\n raw_metrics['RMSE'][1] / raw_metrics['RMSE'][2])\n summary_metric['test_loss'] = (raw_metrics['test_loss'][0] / raw_metrics['test_loss'][1]).item()\n return summary_metric\n\n\ndef train(model, device=torch.device('cuda'), num_epochs = 1, learning_rate = 1e-3):\n train_len = len(train_loader)\n optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n loss_summary = np.zeros((train_len * num_epochs))\n \n for epoch in range(num_epochs):\n model.train()\n loss_epoch = np.zeros(len(train_loader))\n\n for i, (train_batch, idx, labels_batch) in enumerate(tqdm(train_loader)):\n optimizer.zero_grad()\n batch_size = train_batch.shape[0]\n\n train_batch = train_batch.permute(1, 0, 2).to(torch.float32).to(device)\n labels_batch = labels_batch.permute(1, 0).to(torch.float32).to(device)\n idx = idx.unsqueeze(0).to(device)\n\n loss = torch.zeros(1, device=device)\n hidden = model.init_hidden(batch_size)\n cell = model.init_cell(batch_size)\n\n for t in range(window_size):\n # TODO: if z_t is missing, replace it by output mu from the last time step\n mu, sigma, hidden, cell = model(train_batch[t].unsqueeze_(0).clone(), idx, hidden, cell)\n loss += loss_fn(mu, sigma, labels_batch[t])\n\n loss.backward()\n optimizer.step()\n loss = loss.item() / window_size # loss per timestep\n loss_epoch[i] = loss\n \n loss_summary[epoch * train_len:(epoch + 1) * train_len] = loss_epoch\n \n return loss_summary\n\ndef evaluate(model, test_predict_start= window_size-stride_size, device=torch.device('cuda')):\n raw_metrics = {\n 'RMSE': np.zeros(3), # numerator, denominator, time step count\n 'test_loss': np.zeros(2)\n }\n model.eval()\n with torch.no_grad():\n for i, (test_batch, id_batch, v, labels) in enumerate(tqdm(test_loader)):\n test_batch = test_batch.permute(1, 0, 2).to(torch.float32).to(device)\n id_batch = id_batch.unsqueeze(0).to(device)\n v_batch = v.to(torch.float32).to(device)\n labels = labels.to(torch.float32).to(device)\n batch_size = test_batch.shape[1]\n input_mu = torch.zeros(batch_size, test_predict_start, device=device)\n input_sigma = torch.zeros(batch_size, test_predict_start, device=device)\n hidden = model.init_hidden(batch_size)\n cell = model.init_cell(batch_size)\n\n for t in range(test_predict_start):\n # TODO: if z_t is missing, replace it by output mu from the last time step\n mu, sigma, hidden, cell = model(test_batch[t].unsqueeze(0), id_batch, hidden, cell)\n input_mu[:,t] = v_batch[:, 0] * mu + v_batch[:, 1]\n input_sigma[:,t] = v_batch[:, 0] * sigma\n \n # do ancestral sampling\n samples, sample_mu, sample_sigma = model.test(test_batch, v_batch, id_batch, hidden, cell)\n raw_metrics = update_metrics(raw_metrics, input_mu, input_sigma, sample_mu, labels, test_predict_start, samples)\n test_metrics = final_metrics(raw_metrics)\n return test_metrics","metadata":{"execution":{"iopub.status.busy":"2023-01-09T05:02:11.277826Z","iopub.execute_input":"2023-01-09T05:02:11.278179Z","iopub.status.idle":"2023-01-09T05:02:11.299770Z","shell.execute_reply.started":"2023-01-09T05:02:11.278149Z","shell.execute_reply":"2023-01-09T05:02:11.298659Z"},"trusted":true},"execution_count":26,"outputs":[]},{"cell_type":"code","source":"model = DeepAR(device=torch.device(type='cuda')).cuda()\n\nloss_summary = train(model)\ntest_metrics = evaluate(model)","metadata":{"execution":{"iopub.status.busy":"2023-01-09T03:41:45.748988Z","iopub.execute_input":"2023-01-09T03:41:45.749291Z","iopub.status.idle":"2023-01-09T04:29:18.625354Z","shell.execute_reply.started":"2023-01-09T03:41:45.749265Z","shell.execute_reply":"2023-01-09T04:29:18.624527Z"},"trusted":true},"execution_count":8,"outputs":[{"name":"stderr","text":"100%|██████████| 6080/6080 [47:29<00:00, 2.13it/s]\n","output_type":"stream"}]},{"cell_type":"code","source":"loss_summary","metadata":{"execution":{"iopub.status.busy":"2023-01-09T05:04:00.837394Z","iopub.execute_input":"2023-01-09T05:04:00.838268Z","iopub.status.idle":"2023-01-09T05:04:00.846008Z","shell.execute_reply.started":"2023-01-09T05:04:00.838229Z","shell.execute_reply":"2023-01-09T05:04:00.844683Z"},"trusted":true},"execution_count":29,"outputs":[{"execution_count":29,"output_type":"execute_result","data":{"text/plain":"array([ 1.96823676, 1.87354581, 1.77598445, ..., -1.38604752,\n -1.34913286, -1.35332966])"},"metadata":{}}]},{"cell_type":"code","source":"test_metrics","metadata":{"execution":{"iopub.status.busy":"2023-01-09T05:04:00.848773Z","iopub.execute_input":"2023-01-09T05:04:00.849196Z","iopub.status.idle":"2023-01-09T05:04:00.857370Z","shell.execute_reply.started":"2023-01-09T05:04:00.849151Z","shell.execute_reply":"2023-01-09T05:04:00.856114Z"},"trusted":true},"execution_count":30,"outputs":[{"execution_count":30,"output_type":"execute_result","data":{"text/plain":"{'RMSE': 0.48047670502109513, 'test_loss': 5.506244932252884}"},"metadata":{}}]}]} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment