Differentiable Neural Computer

#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm


def reverse_permute(x):
    return torch.zeros_like(x).scatter_(
        1, x, torch.arange(x.shape[1], device=x.device).view(1, -1).expand_as(x))
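

# reverse_permute gives the row-wise inverse permutation (out[b, x[b, i]] = i);
# it is used below to undo the usage sort. A quick check on an arbitrary permutation:
assert reverse_permute(torch.tensor([[2, 0, 1]])).tolist() == [[1, 2, 0]]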


def oneplus(x):
    # maps any real number to (1, inf); used for the attention strengths beta
    return 1 + F.softplus(x)


def attention(K, q, beta):
    """
    Content-based addressing: sharpened cosine similarity between queries q and
    memory rows K.

    K: [B, N, W]
    q: [B, R, W]
    beta: [B, R]
    return: [B, R, N]
    """
    dot = q @ K.transpose(1, 2)
    sim = q.norm(dim=2)[:, :, None] * K.norm(dim=2)[:, None, :]
    cos = dot / (sim + 1e-4)
    logits = beta[..., None] * cos  # [B, R, N]
    return torch.softmax(logits, -1)
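

# A quick sanity check of the addressing above, using arbitrary sizes that match
# the docstring shapes: the result is one distribution over the N slots per head.
_K, _q = torch.randn(2, 20, 16), torch.randn(2, 4, 16)
_w = attention(_K, _q, oneplus(torch.randn(2, 4)))
assert _w.shape == (2, 4, 20) and torch.allclose(_w.sum(-1), torch.ones(2, 4))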


class DNCCell(nn.Module):
    def __init__(self,
                 input_size,
                 output_size,
                 lstm_hidden_size=64,
                 num_memory_slots=20,
                 memory_slot_size=16,
                 num_read_heads=4,
                 num_write_heads=1,
                 clip_value=20):
        super().__init__()
        self.lstm_hidden_size = lstm_hidden_size
        self.output_size = output_size
        self.memory_slot_size = memory_slot_size
        self.num_read_heads = num_read_heads
        self.num_write_heads = num_write_heads
        self.num_memory_slots = num_memory_slots
        self.lstm_cell = nn.LSTMCell(
            input_size + num_read_heads * memory_slot_size, lstm_hidden_size)
        interface_size = memory_slot_size * (num_read_heads + 3 * num_write_heads) + \
            (3 + 2 * num_write_heads) * num_read_heads + \
            3 * num_write_heads
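        # With R read heads, T write heads and word size W, the interface packs:
        # read keys k_r (R*W), read strengths beta_r (R), write keys k_w (T*W),
        # write strengths beta_w (T), erase vectors e (T*W), write vectors v (T*W),
        # free gates f (R), allocation gates g_a (T), write gates g_w (T) and
        # read modes pi ((2T + 1) * R); this sums to the expression above.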
        self.controller_output = nn.Linear(lstm_hidden_size, output_size + interface_size)
        self.reader_output = nn.Linear(memory_slot_size * num_read_heads, output_size)
        self.clip_value = clip_value

    def init(self, batch_size, device):
        h = torch.zeros(batch_size, self.lstm_hidden_size, device=device)
        c = torch.zeros(batch_size, self.lstm_hidden_size, device=device)
        r = torch.zeros(batch_size, self.num_read_heads, self.memory_slot_size, device=device)
        w_r = torch.zeros(batch_size, self.num_read_heads, self.num_memory_slots, device=device)
        w_w = torch.zeros(batch_size, self.num_write_heads, self.num_memory_slots, device=device)
        u = torch.zeros(batch_size, self.num_memory_slots, device=device)
        M = torch.zeros(batch_size, self.num_memory_slots, self.memory_slot_size, device=device)
        p = torch.zeros(batch_size, self.num_write_heads, self.num_memory_slots, device=device)
        L = torch.zeros(batch_size, self.num_write_heads, self.num_memory_slots, self.num_memory_slots, device=device)
        return {
            'lstm_state': (h, c),
            'r': r,
            'w_r': w_r,
            'w_w': w_w,
            'u': u,
            'M': M,
            'p': p,
            'L': L}

    def forward(self, x, prev_states):
        """
        x: float32[B, X] input
        prev_states: dict
            'lstm_state': LSTM previous state
            'r': float32[B, R, W]: previously read content
            'w_r': float32[B, R, N]: previous read positions
            'w_w': float32[B, T, N]: previous write positions
            'u': float32[B, N]: previous usage vector
            'M': float32[B, N, W]: previous memory bank
            'p': float32[B, T, N]: previous precedence vector
            'L': float32[B, T, N, N]: previous temporal memory linkage
        return:
            y: float32[B, Y] prediction
            next_states: dict with the same keys as prev_states
        """
        B = x.shape[0]  # batch_size
        lstm_state = prev_states['lstm_state']
        r = prev_states['r']      # [B, R, W]
        w_r = prev_states['w_r']  # [B, R, N]
        w_w = prev_states['w_w']  # [B, T, N]
        u = prev_states['u']      # [B, N]
        M = prev_states['M']      # [B, N, W]
        p = prev_states['p']      # [B, T, N]
        L = prev_states['L']      # [B, T, N, N]
        # controller update
        h, c = self.lstm_cell(torch.cat([x, r.view(B, -1)], 1), lstm_state)
        h = h.clamp(min=-self.clip_value, max=self.clip_value)
        c = c.clamp(min=-self.clip_value, max=self.clip_value)
        # controller outputs and interfacing
        nu_and_xi = self.controller_output(h)
        nu, e, v, f, k_w, beta_w, g_a, g_w, pi, k_r, beta_r = nu_and_xi.split(
            [self.output_size, self.num_write_heads * self.memory_slot_size, self.num_write_heads * self.memory_slot_size,
             self.num_read_heads, self.num_write_heads * self.memory_slot_size, self.num_write_heads, self.num_write_heads,
             self.num_write_heads, (2 * self.num_write_heads + 1) * self.num_read_heads,
             self.num_read_heads * self.memory_slot_size, self.num_read_heads], 1)
        e = torch.sigmoid(e).view(B, self.num_write_heads, self.memory_slot_size)  # erase vector kept in (0, 1)
        v = v.view(B, self.num_write_heads, self.memory_slot_size)
        f = torch.sigmoid(f)
        k_w = k_w.view(B, self.num_write_heads, self.memory_slot_size)
        beta_w = oneplus(beta_w)
        g_a = torch.sigmoid(g_a)[..., None]
        g_w = torch.sigmoid(g_w)[..., None]
        pi = torch.softmax(pi.view(B, self.num_read_heads, -1), -1)
        k_r = k_r.view(B, self.num_read_heads, self.memory_slot_size)
        beta_r = oneplus(beta_r)
        # shapes:
        # nu: [B, Y]
        # e: [B, T, W]
        # v: [B, T, W]
        # f: [B, R]
        # k_w: [B, T, W]
        # beta_w: [B, T]
        # g_a: [B, T, 1]
        # g_w: [B, T, 1]
        # pi: [B, R, 2 * T + 1]
        # k_r: [B, R, W]
        # beta_r: [B, R]
        # memory usage: the free gates release slots that were just read
        # (retention psi_r), while last step's writes mark slots as used
        psi_r = (1 - f[..., None] * w_r).prod(1)               # [B, N]
        psi_w = u + (1 - u) * (1 - (1 - w_w).prod(1))          # [B, N]
        u = psi_r * psi_w                                      # [B, N]
        # memory allocation - allocate for each write head one-by-one
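        # Allocation weighting as in the DNC: with phi the ascending-usage
        # ordering,  a[phi[j]] = (1 - u[phi[j]]) * prod_{i < j} u[phi[i]],
        # i.e. the freest slot receives most of the mass. reverse_permute
        # undoes the sort, and usage is bumped after each head so that later
        # write heads avoid slots already claimed in this step.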
        u_i = u
        a_list = []
        g_a_w = g_w * g_a
        for i in range(self.num_write_heads):
            u_sort, u_sort_idx = u_i.sort(1)
            u_sort_revidx = reverse_permute(u_sort_idx)
            u_sort_cumprod = torch.cat([torch.ones(B, 1, device=u.device), u_sort[:, :-1]], 1).cumprod(1)
            a_i = ((1 - u_sort) * u_sort_cumprod).gather(1, u_sort_revidx)  # [B, N]
            u_i += (1 - u_i) * g_a_w[:, i, :] * a_i
            a_list.append(a_i)
        a = torch.stack(a_list, 1)  # [B, T, N]
        # memory write
        c_w = attention(M, k_w, beta_w)          # [B, T, N]
        w_w = g_w * (g_a * a + (1 - g_a) * c_w)  # [B, T, N]
        # erasing by several write heads composes multiplicatively before the
        # new content is added
        M = M * (1 - w_w[:, :, :, None] * e[:, :, None, :]).prod(1) + w_w.transpose(1, 2) @ v
        # temporal memory linkage
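        # Per write head, the linkage matrix records the order of writes:
        #   L[i, j] <- (1 - w_w[i] - w_w[j]) * L[i, j] + w_w[i] * p[j]
        # with a zero diagonal, and the precedence vector decays toward the
        # latest write:  p <- (1 - sum(w_w)) * p + w_w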
        L = (1 - w_w[:, :, :, None] - w_w[:, :, None, :]) * L + w_w[:, :, :, None] * p[:, :, None, :]
        Ldiag = torch.arange(self.num_memory_slots, device=L.device)
        L[:, :, Ldiag, Ldiag] = 0
        p = (1 - w_w.sum(2, keepdim=True)) * p + w_w
        # memory read
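        # Each read head mixes three addressing modes with its softmax pi:
        # content lookup c_r, and backward / forward traversal of every write
        # head's linkage from the previous read position.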
        c_r = attention(M, k_r, beta_r)  # [B, R, N]
        w_r_ex = w_r[:, None, :, :].expand(-1, self.num_write_heads, -1, -1)
        b_r = (w_r_ex @ L).transpose(1, 2)                  # [B, R, T, N]
        f_r = (w_r_ex @ L.transpose(2, 3)).transpose(1, 2)  # [B, R, T, N]
        pi_c, pi_b, pi_f = pi.split([1, self.num_write_heads, self.num_write_heads], 2)
        w_r = pi_c * c_r + torch.einsum('ijk,ijkl->ijl', pi_b, b_r) + torch.einsum('ijk,ijkl->ijl', pi_f, f_r)
        r = w_r @ M  # [B, R, W]
        # prediction
        y = nu + self.reader_output(r.view(B, -1))
        y = y.clamp(min=-self.clip_value, max=self.clip_value)
        return y, {
            'lstm_state': (h, c),
            'r': r,
            'w_r': w_r,
            'w_w': w_w,
            'u': u,
            'M': M,
            'p': p,
            'L': L}
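

# A minimal smoke test of a single DNC step; the sizes here are arbitrary and
# the shapes follow the forward() docstring.
_cell = DNCCell(input_size=8, output_size=8)
_state = _cell.init(batch_size=2, device='cpu')
_y, _state = _cell(torch.randn(2, 8), _state)
assert _y.shape == (2, 8)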


class RepeatedCopy(object):
    """
    Repeat a null-terminated string n_repeats times.
    """
    def __init__(self, n_tokens=10, length_min=5, length_max=10, n_repeats=4, batch_size=4):
        self.n_tokens = n_tokens
        self.length_min = length_min
        self.length_max = length_max
        self.n_repeats = n_repeats
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            x = torch.zeros(self.batch_size, self.length_max * self.n_repeats + 1, dtype=torch.int64)
            x_enc = torch.zeros(self.batch_size, self.length_max * self.n_repeats + 1, self.n_tokens + 1)
            y = torch.zeros(self.batch_size, self.length_max * self.n_repeats + 1, dtype=torch.int64)
            y_mask = torch.zeros(self.batch_size, self.length_max * self.n_repeats + 1, dtype=torch.bool)
            for i in range(self.batch_size):
                l = torch.randint(self.length_min, self.length_max + 1, (1,)).item()
                sx = torch.randint(0, self.n_tokens, (l,))
                sy = sx.repeat(self.n_repeats)
                sx = torch.cat([sx, torch.LongTensor([self.n_tokens])])  # null-terminated
                sy = torch.cat([sy, torch.LongTensor([self.n_tokens])])  # null-terminated
                lx = l + 1
                ly = self.n_repeats * l + 1
                x_enc[i, :lx].scatter_(1, sx.view(-1, 1), torch.ones(lx, 1))
                x[i, :lx] = sx
                y[i, :ly] = sy
                y_mask[i, :ly] = 1
            yield x, x_enc, y, y_mask
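

# A peek at one batch with the default settings (10 content tokens plus token id
# 10 as the terminator; sequence length length_max * n_repeats + 1 = 41): x_enc
# is the one-hot input (all-zero once the input string has ended), y is the
# string repeated n_repeats times followed by the terminator, and y_mask marks
# the positions where the loss and accuracy are evaluated.
_x, _x_enc, _y, _y_mask = next(iter(RepeatedCopy(batch_size=1)))
assert _x_enc.shape == (1, 41, 11) and _y_mask.dtype == torch.bool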


def train_model(cell, task, N_ITERS):
    opt = torch.optim.Adam(cell.parameters())
    with tqdm.trange(N_ITERS) as tq:
        for T, (x, x_enc, y, y_mask) in enumerate(task):
            if T == N_ITERS:
                break
            x = x.to(dev)
            x_enc = x_enc.to(dev)
            y = y.to(dev)
            y_mask = y_mask.to(dev)
            y_hat = torch.zeros_like(x_enc)
            state = cell.init(x.shape[0], x.device)
            # unroll the DNC cell over the whole sequence
            for i in range(x.shape[1]):
                yt_hat, state = cell(x_enc[:, i], state)
                y_hat[:, i] = yt_hat
            # compute the loss only on the masked (target) positions
            y_mask = y_mask.view(-1)
            y = y.view(-1)[y_mask]
            y_hat = y_hat.view(-1, task.n_tokens + 1)[y_mask]
            y_pred = y_hat.argmax(1)
            loss = F.cross_entropy(y_hat, y)
            acc = (y_pred == y).float().mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % acc.item()}, refresh=False)
            tq.update()


def test_model(cell, task, N_TESTS):
    elem_acc_sum = 0
    elem_acc_count = 0
    with tqdm.trange(N_TESTS) as tq, torch.no_grad():
        for T, (x, x_enc, y, y_mask) in enumerate(task):
            if T == N_TESTS:
                break
            x = x.to(dev)
            x_enc = x_enc.to(dev)
            y = y.to(dev)
            y_mask = y_mask.to(dev)
            y_hat = torch.zeros_like(x_enc)
            state = cell.init(x.shape[0], x.device)
            for i in range(x.shape[1]):
                yt_hat, state = cell(x_enc[:, i], state)
                y_hat[:, i] = yt_hat
            correct = (y_hat.argmax(-1) == y)
            # We measure element-wise accuracy
            elem_correct = correct.view(-1)[y_mask.view(-1)].float()
            elem_acc = elem_correct.mean()
            elem_acc_sum += elem_correct.sum()
            elem_acc_count += len(elem_correct)
            tq.set_postfix(
                {'elem_acc': '%.03f' % elem_acc.item()},
                refresh=False)
            tq.update()
    return elem_acc_sum / elem_acc_count


dev = 'cpu'
# train on short strings (length 5-10) and test generalisation to longer ones (12-16)
task = RepeatedCopy()
test_task = RepeatedCopy(length_min=12, length_max=16, batch_size=20)
cell = DNCCell(task.n_tokens + 1, task.n_tokens + 1).to(dev)
train_model(cell, task, 3500)
print(test_model(cell, test_task, 500))