Parity problem with PFSA
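The notebook below trains a probabilistic finite-state automaton (PFSA) on bit strings that contain an even number of ones, each terminated by the end symbol 2, and tracks its loss on longer out-of-distribution strings and on odd-parity strings. In notation chosen here only for exposition (the code calls these quantities initial and transition_logits), the model scores a sequence x_1..x_T, whose last symbol x_T is the end symbol, as

    \log p(x_{1:T}) = \log\big[\pi^\top A_{x_1} \cdots A_{x_T}\big]_{\text{final}}

where \pi is the initial state distribution, A_s the transition matrix of symbol s, and the probabilities are normalised jointly over symbol and next state, i.e. \sum_{s,j} (A_s)_{ij} = 1 for every state i, matching _normalise in the PFSA class. Training uses strings of length at most 20.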
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0e842850",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import random\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d0ae549a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample a bit string whose number of ones is even (generate_even=True) or odd,\n",
    "# stopping with probability 0.1 whenever the running parity matches the target.\n",
    "def generate_parity(generate_even=True):\n",
    "    end_state = \"EVEN\" if generate_even else \"ODD\"\n",
    "\n",
    "    rn = random.random()\n",
    "    if rn < 0.5:\n",
    "        result = [0]\n",
    "        state = \"EVEN\"\n",
    "    else:\n",
    "        result = [1]\n",
    "        state = \"ODD\"\n",
    "\n",
    "    while True:\n",
    "        if state == end_state:\n",
    "            rn = random.random()\n",
    "            if rn < 0.1:\n",
    "                break\n",
    "\n",
    "        rn = random.random()\n",
    "        if state == \"EVEN\":\n",
    "            if rn < 0.5:\n",
    "                result.append(1)\n",
    "                state = \"ODD\"\n",
    "            else:\n",
    "                result.append(0)\n",
    "                state = \"EVEN\"\n",
    "        else:  # state == \"ODD\"\n",
    "            if rn < 0.5:\n",
    "                result.append(1)\n",
    "                state = \"EVEN\"\n",
    "            else:\n",
    "                result.append(0)\n",
    "                state = \"ODD\"\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "79d6c5dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "58 34 0100010100001111111011110110111011100100101101011011010101\n",
      "29 16 10010000111010110001111111001\n",
      "1 0 0\n",
      "8 6 11011011\n",
      "4 2 1001\n"
     ]
    }
   ],
   "source": [
    "for i in range(5):\n",
    "    string = generate_parity(generate_even=True)\n",
    "    print(len(string), sum(string), ''.join(str(n) for n in string))\n",
    "\n",
    "def create_batch(batch_size, device=torch.device('cpu'), generate_even=True, min_length=0, max_length=100000):\n",
    "    batch_list = []\n",
    "    while len(batch_list) < batch_size:\n",
    "        # Each instance is a parity string followed by the end symbol 2.\n",
    "        instance = generate_parity(generate_even) + [2]\n",
    "        if min_length <= len(instance) <= max_length:\n",
    "            batch_list.append(instance)\n",
    "\n",
    "    # Pad with -1 to the longest sequence in the batch (time-major layout).\n",
    "    pad_length = max(len(s) for s in batch_list)\n",
    "    batch = np.full((pad_length, batch_size), -1)\n",
    "    for i in range(batch_size):\n",
    "        batch[:, i][:len(batch_list[i])] = batch_list[i]\n",
    "    return torch.tensor(batch, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b43eba5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def restricted_mask():\n",
    "    # Hand-written parity automaton, indexed [state, symbol, next_state]:\n",
    "    # state 0 = start, state 1 = odd number of ones, state 2 = even number\n",
    "    # of ones, state 3 = accept (reached only on the end symbol 2).\n",
    "    transitions = np.zeros((4, 3, 4), dtype=np.bool_)\n",
    "\n",
    "    transitions[0, 0, 2] = 1  # start --0--> even\n",
    "    transitions[0, 1, 1] = 1  # start --1--> odd\n",
    "\n",
    "    transitions[1, 0, 1] = 1  # odd   --0--> odd\n",
    "    transitions[1, 1, 2] = 1  # odd   --1--> even\n",
    "\n",
    "    transitions[2, 0, 2] = 1  # even  --0--> even\n",
    "    transitions[2, 1, 1] = 1  # even  --1--> odd\n",
    "    transitions[2, 2, 3] = 1  # even  --2--> accept\n",
    "\n",
    "    print(transitions.astype(np.int32))\n",
    "    return transitions"
   ]
  },
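  {
   "cell_type": "markdown",
   "id": "mask-demo-note",
   "metadata": {},
   "source": [
    "`restricted_mask` is not called by the training loop further down; it only encodes the hand-written parity automaton as a boolean table. As a quick illustration (a sketch added here; `run_dfa` is a helper name chosen for this cell, not taken from the original code), walking a generated string plus the end symbol 2 through the table should reach state 3 exactly when the string has even parity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mask-demo-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Follow the hand-crafted transition table for a generated string: an\n",
    "# even-parity string followed by the end symbol 2 should land in the\n",
    "# accepting state 3, while an odd-parity string has no allowed transition\n",
    "# on the end symbol (returned as None).\n",
    "mask = restricted_mask()  # note: also prints the transition table\n",
    "\n",
    "def run_dfa(bits, mask):\n",
    "    state = 0\n",
    "    for sym in bits + [2]:\n",
    "        nxt = np.nonzero(mask[state, sym])[0]\n",
    "        if len(nxt) == 0:\n",
    "            return None\n",
    "        state = int(nxt[0])\n",
    "    return state\n",
    "\n",
    "print(run_dfa(generate_parity(generate_even=True), mask))   # expect 3\n",
    "print(run_dfa(generate_parity(generate_even=False), mask))  # expect None"
   ]
  },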
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "96dadd7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PFSA(nn.Module):\n",
    "\n",
    "    def __init__(self, n_states, n_symbols, end_symbol):\n",
    "        super().__init__()\n",
    "        # Unnormalised initial-state scores and transition scores,\n",
    "        # indexed [state, symbol, next_state].\n",
    "        self.initial = nn.Parameter(torch.randn(n_states))\n",
    "        self.transition_logits = nn.Parameter(torch.randn(n_states, n_symbols, n_states))\n",
    "        self.end_symbol = end_symbol\n",
    "\n",
    "    def _normalise(self):\n",
    "        log_init = torch.log_softmax(self.initial, dim=-1)\n",
    "        # Normalise jointly over (symbol, next state) for each source state.\n",
    "        z = torch.logsumexp(self.transition_logits, dim=(-2, -1), keepdim=True)\n",
    "        log_probs = self.transition_logits - z\n",
    "        return log_init, log_probs\n",
    "\n",
    "    def log_mult(self, state, exp_T):\n",
    "        # Log-space vector-matrix product: shift by the per-example max for\n",
    "        # numerical stability, multiply in probability space, shift back.\n",
    "        state_k, _ = torch.max(state, dim=-1, keepdim=True)\n",
    "        p_state = torch.exp(state - state_k)\n",
    "        n_state = torch.einsum('bi,bij->bj', p_state, exp_T)\n",
    "        log_n_state = torch.log(n_state) + state_k\n",
    "        return log_n_state\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: (time, batch) symbol ids, padded with -1 after the end symbol.\n",
    "        log_init, log_probs = self._normalise()\n",
    "        log_state = log_init[None, :].expand(x.size(1), -1)\n",
    "        # Per-step transition matrices, shape (time, batch, state, next state).\n",
    "        # The -1 padding wraps to the last symbol index here, but those steps\n",
    "        # never touch `final` below.\n",
    "        transitions = torch.exp(log_probs[:, x, :]).permute(1, 2, 0, 3)\n",
    "        final = torch.zeros_like(x[0], dtype=torch.float)\n",
    "        for t in range(x.size(0)):\n",
    "            log_state = self.log_mult(log_state, transitions[t])\n",
    "            # After consuming the end symbol, read off the log-mass in the\n",
    "            # designated final state (the last state index).\n",
    "            e = x[t] == self.end_symbol\n",
    "            final[e] = log_state[e, -1]\n",
    "        return final"
   ]
  },
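  {
   "cell_type": "markdown",
   "id": "pfsa-check-note",
   "metadata": {},
   "source": [
    "A minimal sanity check of the `PFSA` module above (added here; the names `_pfsa` and `_x` are placeholders for this cell only): the jointly normalised transition probabilities should sum to 1 for every source state, and `forward` should return one finite log-probability per sequence in a padded batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pfsa-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "_pfsa = PFSA(4, 3, end_symbol=2)\n",
    "_log_init, _log_probs = _pfsa._normalise()\n",
    "# Joint normalisation over (symbol, next state): each of the 4 sums is ~1.\n",
    "print(torch.exp(_log_probs).sum(dim=(-2, -1)))\n",
    "\n",
    "_x = create_batch(8, max_length=10)\n",
    "# One log-probability per batch element, all finite for an untrained model.\n",
    "print(_pfsa(_x))"
   ]
  },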
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2191a1e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "pfsa = PFSA(50, 3, end_symbol=2)\n",
    "\n",
    "# Held-out batches: in-distribution (same length cap as training),\n",
    "# longer-than-training (OOD), and odd-parity strings.\n",
    "id_test = create_batch(100, max_length=20)\n",
    "id_length = torch.sum(id_test != -1)\n",
    "pos_test = create_batch(100, min_length=21)\n",
    "pos_length = torch.sum(pos_test != -1)\n",
    "neg_test = create_batch(100, generate_even=False)\n",
    "neg_length = torch.sum(neg_test != -1)\n",
    "\n",
    "optimizer = torch.optim.Adam(pfsa.parameters(), lr=1e-3)\n",
    "log = []\n",
    "for i in range(20000):\n",
    "    x = create_batch(100, max_length=20)\n",
    "    lengths = torch.sum(x != -1)  # number of non-padding tokens in the batch\n",
    "    log_prob = pfsa(x)\n",
    "    loss = -log_prob.sum() / lengths  # per-token negative log-likelihood\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    optimizer.zero_grad()\n",
    "    if i % 200 == 0:\n",
    "        pfsa.eval()\n",
    "        id_test_loss = -pfsa(id_test).sum() / id_length\n",
    "        pos_test_loss = -pfsa(pos_test).sum() / pos_length\n",
    "        neg_test_loss = -pfsa(neg_test).sum() / neg_length\n",
    "        log.append((id_test_loss.item(), pos_test_loss.item(), neg_test_loss.item()))\n",
    "        pfsa.train()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a362d057",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "t = list(range(len(log)))\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(t, [x[0] for x in log], label=r'In-distribution ($\\leq$ 20)')\n",
    "plt.plot(t, [x[1] for x in log], label='OOD ($>$ 20)')\n",
    "plt.plot(t, [x[2] for x in log], label='Odd parity')\n",
    "plt.ylabel('Loss')\n",
    "plt.xlabel('Iterations')\n",
    "plt.legend()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}