Created
August 24, 2017 07:08
-
-
Save geffy/03d5f9a2cefc7b8feff993435fd3b139 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from scipy.sparse import rand as sprand\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"from torch.utils.data import TensorDataset, DataLoader\n", | |
"from torch.autograd import Variable\n", | |
"import pickle\n", | |
"import tqdm\n", | |
"from sklearn.metrics import mean_squared_error" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# SECURITY: unpickling can execute arbitrary code; only load dataset files\n", | |
"# that come from a trusted source.\n", | |
"with open('../data/processed/ml-1M.pickled', 'rb') as f:\n", | |
"    data = pickle.load(f)\n", | |
"\n", | |
"# Embedding lookups need 64-bit integer indices. np.int64 is explicit,\n", | |
"# whereas np.long is missing from NumPy 1.24-1.26 and platform-dependent\n", | |
"# in the releases that do provide it.\n", | |
"tr_ds = TensorDataset(torch.from_numpy(data['X_tr'].astype(np.int64)),\n", | |
"                      torch.from_numpy(data['y_tr'].astype(np.float32)))\n", | |
"tr_iter = DataLoader(tr_ds, batch_size=1024, shuffle=True)\n", | |
"\n", | |
"te_ds = TensorDataset(torch.from_numpy(data['X_te'].astype(np.int64)),\n", | |
"                      torch.from_numpy(data['y_te'].astype(np.float32)))\n", | |
"# Keep the test order fixed so predictions line up with data['y_te'].\n", | |
"te_iter = DataLoader(te_ds, batch_size=1024, shuffle=False)\n", | |
"\n", | |
"# Column 0 is the user id and column 1 the item id. Size the embedding\n", | |
"# tables from both splits so an id seen only in the test set stays in range.\n", | |
"max_ids = np.maximum(data['X_tr'].max(axis=0), data['X_te'].max(axis=0))\n", | |
"n_users, n_items = (int(m) + 1 for m in max_ids)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Embedding layer with a custom (narrow, zero-mean) normal initialization.\n", | |
"class ScaledEmbedding(nn.Embedding):\n", | |
"    \"\"\"nn.Embedding whose weights are drawn with std 0.1 / embedding_dim.\"\"\"\n", | |
"\n", | |
"    def reset_parameters(self):\n", | |
"        \"\"\"Re-draw every weight and zero the padding row when one is set.\"\"\"\n", | |
"        std = 0.1 / self.embedding_dim\n", | |
"        weights = self.weight.data\n", | |
"        weights.normal_(mean=0, std=std)\n", | |
"        pad = self.padding_idx\n", | |
"        if pad is None:\n", | |
"            return\n", | |
"        weights[pad].fill_(0)\n", | |
" \n", | |
"\n", | |
"class MatrixFactorization(nn.Module):\n", | |
"    \"\"\"Scores a (user, item) pair from the product of their embeddings.\"\"\"\n", | |
"\n", | |
"    def __init__(self, n_users, n_items, n_factors=20):\n", | |
"        \"\"\"Create one n_factors-wide embedding table for users and one for items.\"\"\"\n", | |
"        super().__init__()\n", | |
"        self.user_factors = ScaledEmbedding(n_users, n_factors)\n", | |
"        self.item_factors = ScaledEmbedding(n_items, n_factors)\n", | |
"\n", | |
"    def forward(self, user, item):\n", | |
"        \"\"\"Multiply the looked-up user and item vectors elementwise and sum over dim 1.\n", | |
"\n", | |
"        For 1-D index tensors this is one dot product per (user, item) row.\n", | |
"        \"\"\"\n", | |
"        user_vecs = self.user_factors(user)\n", | |
"        item_vecs = self.item_factors(item)\n", | |
"        return torch.sum(user_vecs * item_vecs, 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from torch.optim.lr_scheduler import ReduceLROnPlateau\n", | |
"\n", | |
"model = MatrixFactorization(n_users, n_items, n_factors=60)\n", | |
"mse = nn.MSELoss()\n", | |
"# weight_decay adds a (very small) L2 penalty on all parameters.\n", | |
"optimizer = torch.optim.Adam(\n", | |
"    model.parameters(),\n", | |
"    lr=1e-3,\n", | |
"    weight_decay=1e-9,\n", | |
")\n", | |
"# Multiply the learning rate by 0.1 once the monitored value has failed to\n", | |
"# improve by at least 3e-4 (absolute) for more than `patience` steps.\n", | |
"scheduler = ReduceLROnPlateau(\n", | |
"    optimizer,\n", | |
"    factor=0.1,\n", | |
"    patience=3,\n", | |
"    threshold=3e-4,\n", | |
"    threshold_mode='abs',\n", | |
"    verbose=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Train" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def train():\n", | |
"    \"\"\"Run one optimisation pass over ``tr_iter``.\n", | |
"\n", | |
"    Uses the module-level ``model``, ``mse`` and ``optimizer``. Prints the\n", | |
"    mean batch MSE and its square root once the pass is finished.\n", | |
"\n", | |
"    Returns:\n", | |
"        list of float: one loss value per mini-batch, in iteration order.\n", | |
"    \"\"\"\n", | |
"    loss_hist = []\n", | |
"    for batch_X, batch_y in tqdm.tqdm(tr_iter):\n", | |
"        optimizer.zero_grad()\n", | |
"\n", | |
"        # Column 0 holds user indices, column 1 item indices. Calling the\n", | |
"        # module itself (not .forward) keeps nn.Module hooks working.\n", | |
"        prediction = model(batch_X[:, 0], batch_X[:, 1])\n", | |
"        loss = mse(prediction, batch_y.float())\n", | |
"\n", | |
"        # The loss is a 0-dim tensor on PyTorch >= 0.4, so indexing its numpy\n", | |
"        # view with [0] fails; .item() returns the Python scalar.\n", | |
"        loss_hist.append(loss.item())\n", | |
"\n", | |
"        # Backpropagate\n", | |
"        loss.backward()\n", | |
"\n", | |
"        # Update the parameters\n", | |
"        optimizer.step()\n", | |
"\n", | |
"    mean_loss = np.mean(loss_hist)\n", | |
"    print('mse: {}| rmse: {}'.format(mean_loss, np.sqrt(mean_loss)))\n", | |
"    return loss_hist" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def test():\n", | |
"    \"\"\"Predict a rating for every row served by ``te_iter``.\n", | |
"\n", | |
"    Uses the module-level ``model``; the targets in each batch are ignored.\n", | |
"\n", | |
"    Returns:\n", | |
"        numpy.ndarray: the batch predictions concatenated in loader order.\n", | |
"    \"\"\"\n", | |
"    pred_hist = []\n", | |
"    # Inference only: no_grad avoids building an autograd graph per batch.\n", | |
"    with torch.no_grad():\n", | |
"        for batch_X, _ in tqdm.tqdm(te_iter):\n", | |
"            prediction = model(batch_X[:, 0], batch_X[:, 1])\n", | |
"            pred_hist.append(prediction.numpy())\n", | |
"    return np.concatenate(pred_hist)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 74.39it/s]\n", | |
" 14%|█▍ | 14/98 [00:00<00:00, 135.49it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 4.653479099273682| rmse: 2.1571924686431885\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 177.92it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 47.60it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 1.01173530424\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:12<00:00, 72.27it/s]\n", | |
" 22%|██▏ | 22/98 [00:00<00:00, 219.62it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.9061034917831421| rmse: 0.9518947005271912\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 165.33it/s]\n", | |
" 0%| | 2/880 [00:00<00:51, 17.15it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1 0.93067693311\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:12<00:00, 70.76it/s]\n", | |
" 21%|██▏ | 21/98 [00:00<00:00, 161.11it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.851354718208313| rmse: 0.9226888418197632\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 187.60it/s]\n", | |
" 1%| | 5/880 [00:00<00:23, 38.01it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2 0.9221900053\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 77.60it/s]\n", | |
" 14%|█▍ | 14/98 [00:00<00:00, 136.96it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.8407105207443237| rmse: 0.9169026613235474\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 175.47it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 47.06it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"3 0.917020341644\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 76.72it/s]\n", | |
" 13%|█▎ | 13/98 [00:00<00:00, 128.18it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.8308556079864502| rmse: 0.9115127921104431\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 173.40it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 47.75it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4 0.911672158106\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 78.59it/s]\n", | |
" 21%|██▏ | 21/98 [00:00<00:00, 202.82it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.8189589977264404| rmse: 0.9049635529518127\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 191.10it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 46.53it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"5 0.907034108872\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 77.61it/s]\n", | |
" 17%|█▋ | 17/98 [00:00<00:00, 141.66it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.8076307773590088| rmse: 0.8986827731132507\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 173.19it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 48.50it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"6 0.902936388422\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 76.86it/s]\n", | |
" 19%|█▉ | 19/98 [00:00<00:00, 185.28it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.7959813475608826| rmse: 0.8921778798103333\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 172.34it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 47.96it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7 0.898566136779\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 74.25it/s]\n", | |
" 13%|█▎ | 13/98 [00:00<00:00, 128.63it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.7819766402244568| rmse: 0.8842944502830505\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 169.65it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 47.47it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"8 0.891737408583\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 74.63it/s]\n", | |
" 20%|██ | 20/98 [00:00<00:00, 154.51it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.7655534148216248| rmse: 0.8749591112136841\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 187.24it/s]\n", | |
" 0%| | 3/880 [00:00<00:31, 27.48it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"9 0.884484992531\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:12<00:00, 78.19it/s]\n", | |
" 14%|█▍ | 14/98 [00:00<00:00, 138.73it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.7477543354034424| rmse: 0.8647279143333435\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 174.74it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 46.39it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"10 0.878235862071\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 77.35it/s]\n", | |
" 13%|█▎ | 13/98 [00:00<00:00, 127.24it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.7289107441902161| rmse: 0.853762686252594\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 173.27it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 48.17it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"11 0.872880039698\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 76.73it/s]\n", | |
" 19%|█▉ | 19/98 [00:00<00:00, 189.53it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.708349883556366| rmse: 0.8416352272033691\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 184.58it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 47.70it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"12 0.866932788908\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 75.45it/s]\n", | |
" 20%|██ | 20/98 [00:00<00:00, 191.45it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.6866269707679749| rmse: 0.8286295533180237\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 184.78it/s]\n", | |
" 0%| | 4/880 [00:00<00:22, 38.22it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"13 0.861228102471\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:12<00:00, 68.93it/s]\n", | |
" 12%|█▏ | 12/98 [00:00<00:00, 112.82it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.6639280319213867| rmse: 0.8148177862167358\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 133.63it/s]\n", | |
" 0%| | 3/880 [00:00<00:32, 26.69it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"14 0.857208540688\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:12<00:00, 68.46it/s]\n", | |
" 13%|█▎ | 13/98 [00:00<00:00, 128.50it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.6401769518852234| rmse: 0.8001105785369873\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 175.45it/s]\n", | |
" 1%| | 5/880 [00:00<00:19, 45.31it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"15 0.854500042774\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 76.96it/s]\n", | |
" 11%|█ | 11/98 [00:00<00:00, 109.56it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.6150626540184021| rmse: 0.7842593193054199\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 170.92it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 47.23it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"16 0.852726459753\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 74.62it/s]\n", | |
" 20%|██ | 20/98 [00:00<00:00, 196.20it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.5880409479141235| rmse: 0.7668382525444031\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 185.30it/s]\n", | |
" 1%| | 5/880 [00:00<00:17, 48.76it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"17 0.85159162539\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 77.37it/s]\n", | |
" 19%|█▉ | 19/98 [00:00<00:00, 189.47it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.5601016879081726| rmse: 0.7483994364738464\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 183.51it/s]\n", | |
" 1%| | 5/880 [00:00<00:23, 37.87it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"18 0.851854084252\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 76.52it/s]\n", | |
" 14%|█▍ | 14/98 [00:00<00:00, 131.52it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.5313001871109009| rmse: 0.7289034128189087\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 171.02it/s]\n", | |
" 1%| | 5/880 [00:00<00:18, 48.28it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"19 0.853770273338\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 880/880 [00:11<00:00, 73.87it/s]\n", | |
" 13%|█▎ | 13/98 [00:00<00:00, 128.89it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"mse: 0.5032721161842346| rmse: 0.7094167470932007\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 98/98 [00:00<00:00, 171.69it/s]\n", | |
" 1%| | 5/880 [00:00<00:17, 48.91it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"20 0.857647674373\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 25%|██▍ | 217/880 [00:02<00:08, 73.79it/s]" | |
] | |
} | |
], | |
"source": [ | |
"# Alternate one training pass and one evaluation pass for 50 epochs,\n", | |
"# recording the RMSE on data['y_te'] after each epoch.\n", | |
"val_hist = []\n", | |
"for i in range(50):\n", | |
"    loss_hist = train()\n", | |
"    pred = test()\n", | |
"    val_err = np.sqrt(mean_squared_error(y_true=data['y_te'], y_pred=pred))\n", | |
"    val_hist.append(val_err)\n", | |
"    # The plateau scheduler lowers the learning rate when this RMSE stalls.\n", | |
"    scheduler.step(val_err)\n", | |
"    print(i, val_err)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# SVD-20: 0.8623" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment