Created
August 24, 2017 07:08
-
-
Save geffy/03d5f9a2cefc7b8feff993435fd3b139 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": false, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "from scipy.sparse import rand as sprand\n", | |
| "import torch\n", | |
| "from torch import nn\n", | |
| "from torch.utils.data import TensorDataset, DataLoader\n", | |
| "from torch.autograd import Variable\n", | |
| "import pickle\n", | |
| "import tqdm\n", | |
| "from sklearn.metrics import mean_squared_error" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": true, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": false, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "data = pickle.load(open('../data/processed/ml-1M.pickled', 'rb'))\n", | |
| "\n", | |
| "tr_ds = TensorDataset(torch.from_numpy(data['X_tr'].astype(np.long)), \n", | |
| " torch.from_numpy(data['y_tr'].astype(np.float32)))\n", | |
| "tr_iter = DataLoader(tr_ds, batch_size=1024, shuffle=True)\n", | |
| "\n", | |
| "te_ds = TensorDataset(torch.from_numpy(data['X_te'].astype(np.long)), \n", | |
| " torch.from_numpy(data['y_te'].astype(np.float32)))\n", | |
| "te_iter = DataLoader(te_ds, batch_size=1024, shuffle=False)\n", | |
| "\n", | |
| "n_users, n_items = np.max(data['X_tr'], axis=0)\n", | |
| "n_users = int(n_users) + 1\n", | |
| "n_items = int(n_items) + 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": false, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": true, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# just Embedding layer with custom normal initialization\n", | |
| "class ScaledEmbedding(nn.Embedding):\n", | |
| " def reset_parameters(self):\n", | |
| " self.weight.data.normal_(0, 0.1 / self.embedding_dim)\n", | |
| " if self.padding_idx is not None:\n", | |
| " self.weight.data[self.padding_idx].fill_(0)\n", | |
| " \n", | |
| "\n", | |
| "class MatrixFactorization(torch.nn.Module):\n", | |
| " \n", | |
| " def __init__(self, n_users, n_items, n_factors=20):\n", | |
| " super().__init__()\n", | |
| " self.user_factors = ScaledEmbedding(n_users, n_factors)\n", | |
| " self.item_factors = ScaledEmbedding(n_items, n_factors)\n", | |
| " \n", | |
| " def forward(self, user, item):\n", | |
| " return (self.user_factors(user) * self.item_factors(item)).sum(1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": true, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", | |
| "\n", | |
| "# Global model, loss, optimizer and LR schedule used by the cells below.\n", | |
| "# 60 latent factors per user/item overrides the class default of 20.\n", | |
| "model = MatrixFactorization(n_users, n_items, n_factors=60)\n", | |
| "mse = torch.nn.MSELoss()\n", | |
| "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-9)\n", | |
| "# Multiply the LR by 0.1 when the value given to scheduler.step() stops\n", | |
| "# improving by more than 3e-4 (absolute) for longer than `patience` steps.\n", | |
| "scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=3, threshold=3e-4, threshold_mode='abs', verbose=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": true, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Train" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "collapsed": false, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def train():\n", | |
| " loss_hist = []\n", | |
| " penalty_hist = []\n", | |
| " for batch_X, batch_y in tqdm.tqdm(tr_iter):\n", | |
| " optimizer.zero_grad()\n", | |
| " \n", | |
| " bX = Variable(batch_X)\n", | |
| " bY = Variable(batch_y.float())\n", | |
| " prediction = model.forward(bX[:, 0], bX[:, 1])\n", | |
| " loss = mse(prediction, bY)\n", | |
| "\n", | |
| " loss_hist.append(loss.data.numpy()[0])\n", | |
| "\n", | |
| " # Backpropagate\n", | |
| " loss.backward()\n", | |
| "\n", | |
| " # Update the parameters\n", | |
| " optimizer.step()\n", | |
| "\n", | |
| " print('mse: {}| rmse: {}'.format(np.mean(loss_hist), np.sqrt(np.mean(loss_hist))))\n", | |
| " return loss_hist" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "collapsed": false, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Test" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": false, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def test():\n", | |
| " pred_hist = []\n", | |
| " for batch_X, batch_y in tqdm.tqdm(te_iter):\n", | |
| " bX = Variable(batch_X)\n", | |
| " bY = Variable(batch_y.float())\n", | |
| " prediction = model.forward(bX[:, 0], bX[:, 1])\n", | |
| " pred_hist.append(prediction.data.numpy())\n", | |
| " return np.concatenate(pred_hist) " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false, | |
| "deletable": true, | |
| "editable": true, | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 74.39it/s]\n", | |
| " 14%|█▍ | 14/98 [00:00<00:00, 135.49it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 4.653479099273682| rmse: 2.1571924686431885\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 177.92it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 47.60it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0 1.01173530424\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:12<00:00, 72.27it/s]\n", | |
| " 22%|██▏ | 22/98 [00:00<00:00, 219.62it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.9061034917831421| rmse: 0.9518947005271912\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 165.33it/s]\n", | |
| " 0%| | 2/880 [00:00<00:51, 17.15it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1 0.93067693311\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:12<00:00, 70.76it/s]\n", | |
| " 21%|██▏ | 21/98 [00:00<00:00, 161.11it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.851354718208313| rmse: 0.9226888418197632\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 187.60it/s]\n", | |
| " 1%| | 5/880 [00:00<00:23, 38.01it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2 0.9221900053\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 77.60it/s]\n", | |
| " 14%|█▍ | 14/98 [00:00<00:00, 136.96it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.8407105207443237| rmse: 0.9169026613235474\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 175.47it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 47.06it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "3 0.917020341644\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 76.72it/s]\n", | |
| " 13%|█▎ | 13/98 [00:00<00:00, 128.18it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.8308556079864502| rmse: 0.9115127921104431\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 173.40it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 47.75it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4 0.911672158106\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 78.59it/s]\n", | |
| " 21%|██▏ | 21/98 [00:00<00:00, 202.82it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.8189589977264404| rmse: 0.9049635529518127\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 191.10it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 46.53it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "5 0.907034108872\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 77.61it/s]\n", | |
| " 17%|█▋ | 17/98 [00:00<00:00, 141.66it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.8076307773590088| rmse: 0.8986827731132507\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 173.19it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 48.50it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "6 0.902936388422\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 76.86it/s]\n", | |
| " 19%|█▉ | 19/98 [00:00<00:00, 185.28it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.7959813475608826| rmse: 0.8921778798103333\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 172.34it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 47.96it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "7 0.898566136779\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 74.25it/s]\n", | |
| " 13%|█▎ | 13/98 [00:00<00:00, 128.63it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.7819766402244568| rmse: 0.8842944502830505\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 169.65it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 47.47it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "8 0.891737408583\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 74.63it/s]\n", | |
| " 20%|██ | 20/98 [00:00<00:00, 154.51it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.7655534148216248| rmse: 0.8749591112136841\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 187.24it/s]\n", | |
| " 0%| | 3/880 [00:00<00:31, 27.48it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "9 0.884484992531\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:12<00:00, 78.19it/s]\n", | |
| " 14%|█▍ | 14/98 [00:00<00:00, 138.73it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.7477543354034424| rmse: 0.8647279143333435\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 174.74it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 46.39it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "10 0.878235862071\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 77.35it/s]\n", | |
| " 13%|█▎ | 13/98 [00:00<00:00, 127.24it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.7289107441902161| rmse: 0.853762686252594\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 173.27it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 48.17it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "11 0.872880039698\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 76.73it/s]\n", | |
| " 19%|█▉ | 19/98 [00:00<00:00, 189.53it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.708349883556366| rmse: 0.8416352272033691\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 184.58it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 47.70it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "12 0.866932788908\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 75.45it/s]\n", | |
| " 20%|██ | 20/98 [00:00<00:00, 191.45it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.6866269707679749| rmse: 0.8286295533180237\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 184.78it/s]\n", | |
| " 0%| | 4/880 [00:00<00:22, 38.22it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "13 0.861228102471\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:12<00:00, 68.93it/s]\n", | |
| " 12%|█▏ | 12/98 [00:00<00:00, 112.82it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.6639280319213867| rmse: 0.8148177862167358\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 133.63it/s]\n", | |
| " 0%| | 3/880 [00:00<00:32, 26.69it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "14 0.857208540688\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:12<00:00, 68.46it/s]\n", | |
| " 13%|█▎ | 13/98 [00:00<00:00, 128.50it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.6401769518852234| rmse: 0.8001105785369873\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 175.45it/s]\n", | |
| " 1%| | 5/880 [00:00<00:19, 45.31it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "15 0.854500042774\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 76.96it/s]\n", | |
| " 11%|█ | 11/98 [00:00<00:00, 109.56it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.6150626540184021| rmse: 0.7842593193054199\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 170.92it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 47.23it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "16 0.852726459753\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 74.62it/s]\n", | |
| " 20%|██ | 20/98 [00:00<00:00, 196.20it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.5880409479141235| rmse: 0.7668382525444031\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 185.30it/s]\n", | |
| " 1%| | 5/880 [00:00<00:17, 48.76it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "17 0.85159162539\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 77.37it/s]\n", | |
| " 19%|█▉ | 19/98 [00:00<00:00, 189.47it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.5601016879081726| rmse: 0.7483994364738464\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 183.51it/s]\n", | |
| " 1%| | 5/880 [00:00<00:23, 37.87it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "18 0.851854084252\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 76.52it/s]\n", | |
| " 14%|█▍ | 14/98 [00:00<00:00, 131.52it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.5313001871109009| rmse: 0.7289034128189087\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 171.02it/s]\n", | |
| " 1%| | 5/880 [00:00<00:18, 48.28it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "19 0.853770273338\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 880/880 [00:11<00:00, 73.87it/s]\n", | |
| " 13%|█▎ | 13/98 [00:00<00:00, 128.89it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "mse: 0.5032721161842346| rmse: 0.7094167470932007\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 98/98 [00:00<00:00, 171.69it/s]\n", | |
| " 1%| | 5/880 [00:00<00:17, 48.91it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "20 0.857647674373\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| " 25%|██▍ | 217/880 [00:02<00:08, 73.79it/s]" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "val_hist = []\n", | |
| "for i in range(50):\n", | |
| " loss_hist = train()\n", | |
| " pred = test()\n", | |
| " val_err = np.sqrt(mean_squared_error(data['y_te'], pred))\n", | |
| " scheduler.step(val_err)\n", | |
| " val_hist.append(val_err)\n", | |
| " print(i, val_err)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true, | |
| "deletable": true, | |
| "editable": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# SVD-20: 0.8623" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.1" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment