Created
November 13, 2023 14:36
-
-
Save strayge/d0c48ebd03a776c43cd68810bd06a85a to your computer and use it in GitHub Desktop.
just testing NN stuff
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 100, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"\n", | |
"METAS = [11,12,13,14,15,16,17]\n", | |
"VALUES = [1,2,3,4,5,6,7,8,9]\n", | |
"EOS = 0\n", | |
"META_LEN = 1\n", | |
"VALUES_LEN = 5\n", | |
"\n", | |
"def conv_meta_to_one_hot(meta):\n", | |
" meta_tensor = torch.zeros(1, len(METAS))\n", | |
" pos = METAS.index(meta)\n", | |
" meta_tensor[0][pos] = 1\n", | |
" return meta_tensor\n", | |
"\n", | |
"def conv_values_to_one_hot(values):\n", | |
" values_tensor = torch.zeros(len(values), 1, len(VALUES))\n", | |
" for i, value in enumerate(values):\n", | |
" pos = VALUES.index(value)\n", | |
" values_tensor[i][0][pos] = 1\n", | |
" return values_tensor\n", | |
"\n", | |
"def make_result_tensor(values):\n", | |
" data = []\n", | |
" for i, value in enumerate(values):\n", | |
" data.append(VALUES.index(value))\n", | |
" data.append(EOS)\n", | |
" tensor = torch.LongTensor(data)\n", | |
" return tensor\n", | |
"\n", | |
"\n", | |
"class RNN(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(RNN, self).__init__()\n", | |
" input_size = len(VALUES)\n", | |
" output_size = len(VALUES)\n", | |
" self.hidden_size = 128\n", | |
"\n", | |
" self.i2h = nn.Linear(len(METAS) + input_size + self.hidden_size, self.hidden_size)\n", | |
" self.i2o = nn.Linear(len(METAS) + input_size + self.hidden_size, output_size)\n", | |
" self.o2o = nn.Linear(self.hidden_size + output_size, output_size)\n", | |
" self.dropout = nn.Dropout(0.1)\n", | |
" self.softmax = nn.LogSoftmax(dim=1)\n", | |
"\n", | |
" def forward(self, meta, input, hidden):\n", | |
" input_combined = torch.cat((meta, input, hidden), 1)\n", | |
" hidden = self.i2h(input_combined)\n", | |
" output = self.i2o(input_combined)\n", | |
" output_combined = torch.cat((hidden, output), 1)\n", | |
" output = self.o2o(output_combined)\n", | |
" output = self.dropout(output)\n", | |
" output = self.softmax(output)\n", | |
" return output, hidden\n", | |
"\n", | |
" def initHidden(self):\n", | |
" return torch.zeros(1, self.hidden_size)\n", | |
"\n", | |
" def train(self, meta_tensor, input_tensor, target_tensor, learning_rate=0.0005):\n", | |
" criterion = nn.NLLLoss()\n", | |
"\n", | |
" target_tensor.unsqueeze_(-1)\n", | |
" hidden = self.initHidden()\n", | |
"\n", | |
" self.zero_grad()\n", | |
" loss = torch.Tensor([0])\n", | |
"\n", | |
" for i in range(input_tensor.size(0)):\n", | |
" output, hidden = self(meta_tensor, input_tensor[i], hidden)\n", | |
" l = criterion(output, target_tensor[i])\n", | |
" loss += l\n", | |
"\n", | |
" loss.backward()\n", | |
"\n", | |
" for p in self.parameters():\n", | |
" p.data.add_(p.grad.data, alpha=-learning_rate)\n", | |
"\n", | |
" return output, loss.item() / input_tensor.size(0)\n", | |
"\n", | |
" def sample(self, meta, start_value=1):\n", | |
" with torch.no_grad():\n", | |
" meta_tensor = conv_meta_to_one_hot(meta)\n", | |
" input = conv_values_to_one_hot([start_value])\n", | |
" hidden = self.initHidden()\n", | |
"\n", | |
" output_name = start_value\n", | |
"\n", | |
" for i in range(VALUES_LEN):\n", | |
" output, hidden = self(meta_tensor, input[0], hidden)\n", | |
" topv, topi = output.topk(1)\n", | |
" topi = topi[0][0]\n", | |
" print(topi)\n", | |
" if topi == len(VALUES) - 1:\n", | |
" break\n", | |
" else:\n", | |
" value = VALUES[topi]\n", | |
" output_name += value\n", | |
" input = conv_values_to_one_hot([value])\n", | |
"\n", | |
" return output_name" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 103, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor(4)\n", | |
"tensor(4)\n", | |
"tensor(4)\n", | |
"tensor(4)\n", | |
"tensor(4)\n", | |
"sample: 26\n" | |
] | |
} | |
], | |
"source": [ | |
"from random import randint\n", | |
"\n", | |
"\n", | |
"data = []\n", | |
"for i in range(1000):\n", | |
" j = randint(1, 7)\n", | |
" meta = 10 + j\n", | |
" data.append(([meta], [j, j, j, j+1, j+2]))\n", | |
"\n", | |
"rnn = RNN()\n", | |
"\n", | |
"for iter in range(0, 1000):\n", | |
" meta, values = data[iter]\n", | |
" meta_tensor = conv_meta_to_one_hot(meta[0])\n", | |
" values_tensor = conv_values_to_one_hot(values)\n", | |
" result_tensor = make_result_tensor(values)\n", | |
" rnn.train(meta_tensor, values_tensor, result_tensor)\n", | |
"\n", | |
"print('sample:', rnn.sample(13))\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment