Created
November 18, 2018 05:17
-
-
Save NTT123/d671e8f0eb78dab36a60c66c5219db54 to your computer and use it in GitHub Desktop.
Positional Attention 4 Simple Mapping
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Positional Attention 4 Simple Mapping", | |
"version": "0.3.2", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/NTT123/d671e8f0eb78dab36a60c66c5219db54/positional-attention-4-simple-mapping.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "gD2g-9Xi6kQ6", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Setup" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "PKJnnnIm2Mb2", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"!pip install -q numpy torchvision_nightly\n", | |
"!pip install -q torch_nightly -f https://download.pytorch.org/whl/nightly/cu92/torch_nightly.html\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"import torch\n", | |
"import numpy as np\n", | |
"from torch.utils.data import Dataset\n", | |
"import math\n", | |
"DEV = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "LQ_W_TqOx8w2", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Data" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "bSLEf3y-CAg8", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"master_dic = {0:9, 1: 3, 2: 6, 3: 1, 4: 9, 5: 7, 6: 8, 7: 0, 8: 4, 9: 5}\n", | |
"\n", | |
"def generate_data_point(l):\n", | |
" x = np.random.randint(low=0, high=10, size=(l,))\n", | |
" y = np.copy(x)\n", | |
" for i in range(len(x)):\n", | |
" y[i] = master_dic[ x[i]]\n", | |
" \n", | |
" return (x, y)\n", | |
"\n", | |
"def generate_data_set(size, l):\n", | |
" return [generate_data_point(l) for _ in range(size)]\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "G-XsLnynERbK", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"class MDS(Dataset):\n", | |
" \"\"\"Number mapping dataset.\"\"\"\n", | |
"\n", | |
" def __init__(self, size=50000, l=100):\n", | |
" self.data = generate_data_set(size, l)\n", | |
" \n", | |
" def __len__(self):\n", | |
" return len(self.data)\n", | |
" \n", | |
" def __getitem__(self, idx):\n", | |
" return self.data[idx]\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "0NhNO8oxEh57", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"dataset = MDS()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "sExqYgqUUHCl", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Model" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "mNpm8hSKH_3g", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"class PosAtt(torch.nn.Module):\n", | |
" \n", | |
" __constants__ = ['input_size', 'num_heads']\n", | |
" \n", | |
" def __init__(self, input_size, num_heads):\n", | |
" super().__init__()\n", | |
" \n", | |
" self.input_size = input_size\n", | |
" self.num_heads = num_heads\n", | |
" \n", | |
" self.linear = torch.nn.Linear(input_size, 3*num_heads)\n", | |
" \n", | |
" def forward(self, x, mem, prev_r, pos):\n", | |
" bs = x.size(0)\n", | |
" a, b, deltar = self.linear(x).view(bs, self.num_heads, 3).chunk(3, dim=2)\n", | |
" \n", | |
" l = mem.size(1) # length of the memory, assumping the same lengths for a min-batch\n", | |
" \n", | |
" \n", | |
" r = prev_r + deltar.exp()\n", | |
" \n", | |
" z = -b.exp() * torch.pow(r - pos, 2)\n", | |
" z = torch.softmax(z, dim=2)\n", | |
" \n", | |
" w = torch.sum(torch.sigmoid(a) * z, dim=1).view(bs, 1, l)\n", | |
" mem = mem.view(bs, l, -1)\n", | |
" \n", | |
" v = torch.bmm(w, mem).squeeze(1) # bs, [dim]\n", | |
" \n", | |
" return r, v, w\n", | |
" \n", | |
" def init_location(self, bs):\n", | |
" return -torch.ones(bs, self.num_heads, 1)\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "ouOXedJjFpgN", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"class Mapper(torch.nn.Module):\n", | |
" def __init__(self, embedding_size=32, hidden_size=64):\n", | |
" super().__init__()\n", | |
" \n", | |
" self.hidden_size = hidden_size\n", | |
" self.embedding_size = embedding_size\n", | |
" \n", | |
" self.embedding = torch.nn.Embedding(10, embedding_size)\n", | |
" self.rnn_cell = torch.nn.GRUCell(embedding_size, hidden_size)\n", | |
" self.posatt = PosAtt(hidden_size, 5)\n", | |
" \n", | |
" self.linear = torch.nn.Linear(embedding_size, 10)\n", | |
"\n", | |
" self.register_parameter('init_hx', \n", | |
" torch.nn.Parameter(\n", | |
" torch.randn(1, hidden_size)) \n", | |
" )\n", | |
" \n", | |
" def forward(self, x, y):\n", | |
" x = self.embedding(x)\n", | |
" y = self.embedding(y)\n", | |
"\n", | |
" l = y.size(1)\n", | |
" \n", | |
" pos = torch.arange(end=l, \n", | |
" dtype=torch.float, \n", | |
" device=DEV, requires_grad=False) \n", | |
" \n", | |
" pos = pos.view(1, 1, -1)\n", | |
" \n", | |
" \n", | |
" bs = x.size(0)\n", | |
" seq_len = x.size(1)\n", | |
" \n", | |
" hidden_state = self.init_hidden_states(bs)\n", | |
" prev_r = self.posatt.init_location(bs).to(DEV)\n", | |
" \n", | |
" hidden_state = self.rnn_cell(torch.zeros_like(y[:, 0]), hidden_state)\n", | |
" prev_r, v, w = self.posatt(hidden_state, x, prev_r, pos)\n", | |
"\n", | |
" list_of_v = [v]\n", | |
" lw = [w]\n", | |
" \n", | |
" for i in range(seq_len-1):\n", | |
" hidden_state = self.rnn_cell(y[:, i], hidden_state)\n", | |
" \n", | |
" prev_r, v, w = self.posatt(hidden_state, x, prev_r, pos)\n", | |
" lw.append(w)\n", | |
" list_of_v.append(v)\n", | |
" \n", | |
" vv = torch.stack(list_of_v, dim=1) \n", | |
" \n", | |
" lw = torch.stack(lw, dim=0)\n", | |
" \n", | |
" return self.linear(vv), lw\n", | |
" \n", | |
" def init_hidden_states(self, bs):\n", | |
" return self.init_hx.expand(bs, self.hidden_size)\n" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "nYOdlJ3_GF9o", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Warmup" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "ft5mV3NuF_FU", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"mapper = Mapper().to(DEV)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "e76tItv8GC9y", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"optimizer = torch.optim.Adam(mapper.parameters(), lr=1e-4)\n", | |
"lossfn = torch.nn.CrossEntropyLoss()\n", | |
"\n", | |
"delta = 10\n", | |
"sl = 0.0\n", | |
"\n", | |
"epoch = 0\n", | |
"\n", | |
"import sys\n", | |
"import datetime" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "YaRpopM6H6Dn", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"testset = MDS(1000,300)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "JcqWooPtHZaq", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
" def test(model, testset):\n", | |
" model.eval()\n", | |
" \n", | |
" dataset_loader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=100)\n", | |
" \n", | |
" sl = 0.0\n", | |
" \n", | |
" past = datetime.datetime.now()\n", | |
" for i, (x, y) in enumerate(dataset_loader):\n", | |
"\n", | |
" x = x.to(DEV)\n", | |
" y = y.to(DEV)\n", | |
" \n", | |
" out, w = mapper(x, y)\n", | |
" loss = 0.0\n", | |
" \n", | |
" for j in range(out.size(1)):\n", | |
" loss += lossfn(out[:, j,:], y[:, j])\n", | |
" \n", | |
" sl = sl + loss.item()\n", | |
" print(f\"\\nTest loss: {sl / i / x.size(1)}\")\n", | |
" model.train()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "AB8swiw6INPO", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 51 | |
}, | |
"outputId": "931d71c8-5098-4799-a6ab-940e08c3d092" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"test(mapper, testset)" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Test loss: 3.191653035481771\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "5yv45iZtUJ77", | |
"colab_type": "text" | |
}, | |
"cell_type": "markdown", | |
"source": [ | |
"## Train" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Q8kTEd7JVpnP", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1680 | |
}, | |
"outputId": "6387b0a4-c5cd-46a1-9ae2-c21321584999" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"start = epoch\n", | |
"\n", | |
"mapper.train()\n", | |
"\n", | |
"for epoch in range(start, 1000):\n", | |
" dataset_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1024)\n", | |
" \n", | |
" sl = 0.0\n", | |
" \n", | |
" past = datetime.datetime.now()\n", | |
" for i, (x, y) in enumerate(dataset_loader):\n", | |
"\n", | |
" x = x.to(DEV)\n", | |
" y = y.to(DEV)\n", | |
" \n", | |
" out, w = mapper(x, y)\n", | |
" loss = 0.0\n", | |
" \n", | |
" for j in range(out.size(1)):\n", | |
" loss += lossfn(out[:, j,:], y[:, j])\n", | |
" \n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" torch.nn.utils.clip_grad_norm_(mapper.parameters(), 1.0) \n", | |
" optimizer.step()\n", | |
" sl = sl + loss.item()\n", | |
" if i % delta == delta-1:\n", | |
" now = datetime.datetime.now()\n", | |
"\n", | |
" sys.stdout.write(f\"\\r {epoch} {i*100.0/len(dataset_loader): .3f} *** {sl/delta/x.size(1): .3f} *** {(now-past).seconds} sec\")\n", | |
" \n", | |
" sl = 0.0\n", | |
" past = now\n", | |
" \n", | |
" test(mapper, testset)" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
" 0 79.592 *** 2.779 *** 5 sec\n", | |
"Test loss: 3.014451429578993\n", | |
" 1 79.592 *** 2.626 *** 5 sec\n", | |
"Test loss: 2.8575190339265046\n", | |
" 2 79.592 *** 2.473 *** 5 sec\n", | |
"Test loss: 2.69341222692419\n", | |
" 3 79.592 *** 2.355 *** 5 sec\n", | |
"Test loss: 2.6040566225405093\n", | |
" 4 79.592 *** 2.307 *** 5 sec\n", | |
"Test loss: 2.5639827247902196\n", | |
" 5 79.592 *** 2.276 *** 5 sec\n", | |
"Test loss: 2.5290764928747107\n", | |
" 6 79.592 *** 2.244 *** 5 sec\n", | |
"Test loss: 2.4884056939019095\n", | |
" 7 79.592 *** 2.194 *** 5 sec\n", | |
"Test loss: 2.422820864076968\n", | |
" 8 79.592 *** 2.108 *** 5 sec\n", | |
"Test loss: 2.331888043438947\n", | |
" 9 79.592 *** 1.997 *** 5 sec\n", | |
"Test loss: 2.193220757378472\n", | |
" 10 79.592 *** 1.845 *** 5 sec\n", | |
"Test loss: 2.0141206868489583\n", | |
" 11 79.592 *** 1.675 *** 5 sec\n", | |
"Test loss: 1.8539030626085071\n", | |
" 12 79.592 *** 1.455 *** 5 sec\n", | |
"Test loss: 1.7659500009042244\n", | |
" 13 79.592 *** 1.259 *** 5 sec\n", | |
"Test loss: 1.387509686505353\n", | |
" 14 79.592 *** 1.097 *** 5 sec\n", | |
"Test loss: 1.2019114515516494\n", | |
" 15 79.592 *** 0.846 *** 5 sec\n", | |
"Test loss: 0.8759097911693433\n", | |
" 16 79.592 *** 0.570 *** 5 sec\n", | |
"Test loss: 0.5826709662543403\n", | |
" 17 79.592 *** 0.397 *** 5 sec\n", | |
"Test loss: 0.41130372789171005\n", | |
" 18 79.592 *** 0.263 *** 5 sec\n", | |
"Test loss: 0.2667104396113643\n", | |
" 19 79.592 *** 0.161 *** 5 sec\n", | |
"Test loss: 0.15735791665536386\n", | |
" 20 79.592 *** 0.076 *** 5 sec\n", | |
"Test loss: 0.0651626968383789\n", | |
" 21 79.592 *** 0.025 *** 5 sec\n", | |
"Test loss: 0.02437548107571072\n", | |
" 22 79.592 *** 0.008 *** 5 sec\n", | |
"Test loss: 0.0069902514528345176\n", | |
" 23 79.592 *** 0.002 *** 5 sec\n", | |
"Test loss: 0.0035060937537087335\n", | |
" 24 79.592 *** 0.001 *** 5 sec\n", | |
"Test loss: 0.001150429039089768\n", | |
" 25 79.592 *** 0.001 *** 5 sec\n", | |
"Test loss: 0.008174553756360654\n", | |
" 26 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.0011200149302129391\n", | |
" 27 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.00038300703245180626\n", | |
" 28 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.0019341255669240599\n", | |
" 29 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.004651734166675144\n", | |
" 30 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.004434576144924871\n", | |
" 31 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.0036194379462136167\n", | |
" 32 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.0012328265441788567\n", | |
" 33 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.0008667809488596739\n", | |
" 34 79.592 *** 0.000 *** 5 sec\n", | |
"Test loss: 0.001637578860477165\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "error", | |
"ename": "KeyboardInterrupt", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-12-890d3728e75b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmapper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \"\"\"\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 88\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 89\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "lxKL86-pr3Zq", | |
"colab_type": "code", | |
"outputId": "90ca02da-e97e-480e-a74e-eb9606bac86d", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 372 | |
} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"plt.matshow(w[0:50,4,0,0:50].cpu().detach().numpy())\n", | |
"plt.axis(\"off\")" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(-0.5, 49.5, 49.5, -0.5)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 13 | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVMAAAFSCAYAAABPFzzRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAABilJREFUeJzt3V1u2ggYhlEbyAqy/x1mBxEw11PZ\nNCUP/uOcS6pRqNQ+80mvoeP9fr8PAPzKae03AHAEYgoQEFOAgJgCBMQUICCmAAExBQiIKUBATAEC\nYgoQEFOAgJgCBMQUICCmAAExBQiIKUBATAECYgoQEFOAgJgCBMQUICCmAAExBQhclvghn5+fk69/\nfX29/GeP4/jynwHgMgUIiClAQEwBAmIKEBBTgMB4v9/va/3wj4+Pyde/v7+znzH327PyAyWXKUBA\nTAECYgoQEFOAgJgCBFZd8+ecTtONv91uC78TgJ9xmQIExBQgIKYAATEFCIgpQGCTa/6cR5+n39Fv\nAzgglylAQEwBAmIKEBBTgICYAgTEFCBwWfsN/ItHjz/NPTblkSlgCS5TgICYAgTEFCAgpgABMQUI\n7GrNf2RutbfyA0twmQIExBQgIKYAATEFCIgpQOAwa/4cKz+wBJcpQEBMAQJiChAQU4CAmAIEDr/m\nz7HyAyWXKUBATAECYgoQEFOAgJgCBN52zZ9j5Qee4TIFCIgpQEBMAQJiChAQU4CANf+HrPzAIy5T\ngICYAgTEFCAgpgABMQUIWPN/6V9X/kf/DbBfLlOAgJgCBMQUICCmAAExBQiIKUDAo1Ev8ujxp9Np\n+v9ht9vtVW8HeDGXKUBATAECYgoQEFOAgJgCBMa7b93YDCs/7JfLFCAgpgABMQUIiClAQEwBAtb8\nHbDyw/a5TAECYgoQEFOAgJgCBMQUIGDN3zErP2yHyxQgIKYAATEFCIgpQEBMAQLW/AM6n8+Tr1+v\n14XfCbwPlylAQEwBAmIKEBBTgICYAgSs+W/EZ/nhdVymAAExBQiIKUBATAECYgoQsOZj5YeAyxQg\nIKYAATEFCIgpQEBMAQJiChDwaBSz5h6ZGgaPTcGfXKYAATEFCIgpQEBMAQJiChCw5vOUcRwnX/fH\niXflMgUIiClAQEwBAmIKEBBTgMBl7TfAPs2t9v4JFN6VyxQgIKYAATEFCIgpQEBMAQI+m88irPwc\nncsUICCmAAExBQiIKUBATAEC1nxWZeXnKFymAAExBQiIKUBATAECYgoQsOazSVZ+9sZlChAQU4CA\nmAIExBQgIKYAAWs+u2LlZ6tcpgABMQUIiClAQEwBAmIKELDmcwhzK/8wWPpZhssUICCmAAExBQiI\nKUBATAECYgoQ8GgUh+fLUViCyxQgIKYAATEFCIgpQEBMAQLWfN7WOI6Tr/srwTNcpgABMQUIiClA\nQEwBAmIKELis/QZgLXOrvc/y8wyXKUBATAECYgoQEFOAgJgCBHw2H37Iys8jLlOAgJgCBMQUICCm\nAAExBQhY8+GXrPwMg8sUICGmAAExBQiIKUBATAEC1nx4ESv/e3GZAgTEFCAgpgABMQUIiClAwJoP\nC5tb+YfB0r9nLlOAgJgCBMQUICCmAAExBQiIKUDAo1GwIb4cZb9cpgABMQUIiClAQEwBAmIKELDm\nww6M4zj5ur++2+EyBQiIKUBATAECYgoQEFOAwGXtNwD83dxq77P82+EyBQiIKUBATAECYgoQEFOA\ngM/mwwH5LP/yXKYAATEFCIgpQEBMAQJiChDw2Xw4IJ/lX57LFCAgpgABMQUIiClAQEwBAj6bD1j5\nAy5TgICYAgTEFCAgpgABMQUIWPOBWXMr/zBY+v/kMgUIiClAQEwBAmIKEBBTgICYAgQ8GgU8xZej\n/J/LFCAgpgABMQUIiClAQEwBAtZ8IHU+nydfv16vC7+TZblMAQJiChAQU4CAmAIExBQgYM0HFnH0\nld9lChAQU4CAmAIExBQgIKYAAWs+sKqjrPwuU4CAmAIExBQgIKYAATEFCFjzgU06naZvvdvttvA7\n+RmXKUBATAECYgoQEFOAgJgCBKz5wK5sdeV3mQIExBQgIKYAATEFCIgpQMCaDxzCOI6zv7ZE5lym\nAAExBQiIKUBATAECYgoQEFOAwGXtNwBQePT40xJfjuIyBQiIKUBATAECYgoQEFOAgC86Ad5WufK7\nTAECYgoQEFOAgJgCBMQUIGDNB/jDMyu/yxQgIKYAATEFCIgpQEBMAQLWfICAyxQgIKYAATEFCIgp\nQEBMAQJiChD4DwxBWetKS+n3AAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x7f51febc4160>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "5aj32luX7_zw", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 187 | |
}, | |
"outputId": "a20cf0e2-aaa4-4599-cdcb-144a51426b36" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"mapper.init_hx" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"Parameter containing:\n", | |
"tensor([[-0.9590, -0.0698, 1.4637, 0.9840, 0.9411, 0.5866, 0.3235, 0.8579,\n", | |
" -0.5576, 0.8684, -1.4785, 1.1317, 1.1024, -0.5567, 0.4271, -0.0969,\n", | |
" 1.3948, -0.0274, -0.5441, -1.0963, -0.9247, 2.6980, -0.5265, -0.4226,\n", | |
" 1.1300, -0.2690, 0.1382, -0.5312, -1.1759, 0.5948, 0.8446, 1.4066,\n", | |
" -1.7117, -0.6705, 1.8127, -2.1777, -0.2969, -0.5643, 0.4175, 1.6721,\n", | |
" -0.2167, -2.0209, 0.4257, -0.4736, 0.0445, -0.6640, 0.6513, 0.9604,\n", | |
" -2.4647, -1.7136, 1.3164, -0.1608, -0.4876, 0.6494, 0.2802, 0.0866,\n", | |
" 1.7351, -2.7250, -0.7138, 0.0080, -1.4024, 0.4650, 0.9506, 0.3659]],\n", | |
" device='cuda:0', requires_grad=True)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 14 | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment