Skip to content

Instantly share code, notes, and snippets.

@BarclayII
Last active December 1, 2020 18:00
Show Gist options
  • Save BarclayII/db424cd39828392646d113a8edb8cfa6 to your computer and use it in GitHub Desktop.
Save BarclayII/db424cd39828392646d113a8edb8cfa6 to your computer and use it in GitHub Desktop.
Differentiable Neural Computer
#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
def reverse_permute(x):
    """Invert a batch of permutations.

    x: int64[B, N]; each row is a permutation of 0..N-1.
    return: int64[B, N] ``inv`` with ``inv[b, x[b, j]] == j``, so that
        ``x.gather(1, inv)`` yields ``0..N-1`` in every row.
    """
    rows, width = x.size(0), x.size(1)
    positions = torch.arange(width, device=x.device).unsqueeze(0).expand(rows, -1)
    inverse = torch.zeros_like(x)
    # Write position j into slot x[b, j] of each row.
    inverse.scatter_(1, x, positions)
    return inverse
def oneplus(x):
    """Return ``1 + softplus(x)`` elementwise.

    The result is a smooth map onto values >= 1. It is used below to turn
    raw controller outputs into key strengths.
    """
    return F.softplus(x).add(1)
def attention(K, q, beta):
    """Content-based addressing: softmax over sharpened cosine similarity.

    K: [B, N, W] memory rows to compare against
    q: [B, R, W] one lookup key per head
    beta: [B, R] per-head multiplier on the similarities (larger = peakier)
    return: [B, R, N] weightings; each row sums to 1 over the N slots
    """
    dot = torch.matmul(q, K.transpose(1, 2))  # [B, R, N]
    q_len = q.norm(dim=2).unsqueeze(2)  # [B, R, 1]
    k_len = K.norm(dim=2).unsqueeze(1)  # [B, 1, N]
    # The 1e-4 keeps the division finite when a key or a memory row is all zeros.
    cosine = dot / (q_len * k_len + 1e-4)
    return torch.softmax(beta.unsqueeze(2) * cosine, dim=2)
class DNCCell(nn.Module):
    """One time step of a Differentiable Neural Computer.

    An LSTM controller reads the current input concatenated with the vectors
    read from memory on the previous step. It emits a direct prediction term
    ``nu`` plus an interface vector that drives memory addressing: content
    lookup, usage-based allocation and temporal links. The prediction is
    ``nu`` plus a linear projection of the freshly read vectors.

    Shape legend: B batch, X input, Y output, R read heads, T write heads,
    N memory slots, W slot width.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 lstm_hidden_size=64,
                 num_memory_slots=20,
                 memory_slot_size=16,
                 num_read_heads=4,
                 num_write_heads=1,
                 clip_value=20):
        """
        input_size: X, width of each per-step input.
        output_size: Y, width of each per-step prediction.
        lstm_hidden_size: hidden/cell size of the LSTM controller.
        num_memory_slots: N, number of memory rows.
        memory_slot_size: W, width of each memory row.
        num_read_heads: R.
        num_write_heads: T.
        clip_value: LSTM h/c and the prediction are clamped to
            [-clip_value, clip_value].
        """
        super().__init__()
        self.lstm_hidden_size = lstm_hidden_size
        self.output_size = output_size
        self.memory_slot_size = memory_slot_size
        self.num_read_heads = num_read_heads
        self.num_write_heads = num_write_heads
        self.num_memory_slots = num_memory_slots
        # Controller input = raw input + all R previously read vectors.
        self.lstm_cell = nn.LSTMCell(
            input_size + num_read_heads * memory_slot_size, lstm_hidden_size)
        # Interface layout (must match the split in ``forward``):
        #   erase, write vector, write key       : 3 * T * W
        #   read keys                            : R * W
        #   free gates, read strengths           : 2 * R
        #   read modes (content + T back + T fwd): (2 * T + 1) * R
        #   write strength, alloc gate, write gate: 3 * T
        interface_size = memory_slot_size * (num_read_heads + 3 * num_write_heads) + \
            (3 + 2 * num_write_heads) * num_read_heads + \
            3 * num_write_heads
        self.controller_output = nn.Linear(lstm_hidden_size, output_size + interface_size)
        self.reader_output = nn.Linear(memory_slot_size * num_read_heads, output_size)
        self.clip_value = clip_value

    def init(self, batch_size, device):
        """Return an all-zero initial state dict (same keys as ``forward`` returns)."""
        h = torch.zeros(batch_size, self.lstm_hidden_size, device=device)
        c = torch.zeros(batch_size, self.lstm_hidden_size, device=device)
        r = torch.zeros(batch_size, self.num_read_heads, self.memory_slot_size, device=device)
        w_r = torch.zeros(batch_size, self.num_read_heads, self.num_memory_slots, device=device)
        w_w = torch.zeros(batch_size, self.num_write_heads, self.num_memory_slots, device=device)
        u = torch.zeros(batch_size, self.num_memory_slots, device=device)
        M = torch.zeros(batch_size, self.num_memory_slots, self.memory_slot_size, device=device)
        p = torch.zeros(batch_size, self.num_write_heads, self.num_memory_slots, device=device)
        L = torch.zeros(batch_size, self.num_write_heads, self.num_memory_slots,
                        self.num_memory_slots, device=device)
        return {
            'lstm_state': (h, c),
            'r': r,
            'w_r': w_r,
            'w_w': w_w,
            'u': u,
            'M': M,
            'p': p,
            'L': L}

    def _allocation_weights(self, u, g_a_w):
        """Compute allocation weightings for every write head.

        u: float32[B, N] current usage.
        g_a_w: float32[B, T, 1] product of write gate and allocation gate.
        return: float32[B, T, N] allocation weighting per write head.

        Heads are served one after another. After each head, a working copy
        of the usage is raised by what that head would allocate, so later
        heads prefer other free slots. The caller's ``u`` is never modified.
        """
        B = u.shape[0]
        usage = u
        a_list = []
        for i in range(self.num_write_heads):
            u_sort, u_sort_idx = usage.sort(1)  # ascending: least-used slots first
            u_sort_revidx = reverse_permute(u_sort_idx)
            # Exclusive cumulative product: entry k is prod_{j<k} u_sort[:, j].
            u_sort_cumprod = torch.cat(
                [torch.ones(B, 1, device=u.device, dtype=u.dtype), u_sort[:, :-1]], 1).cumprod(1)
            a_i = ((1 - u_sort) * u_sort_cumprod).gather(1, u_sort_revidx)  # [B, N]
            # Out-of-place on purpose: ``usage += ...`` would alias and
            # mutate ``u``, corrupting the usage state returned by forward().
            usage = usage + (1 - usage) * g_a_w[:, i, :] * a_i
            a_list.append(a_i)
        return torch.stack(a_list, 1)

    def forward(self, x, prev_states):
        """
        x: float32[B, X] input
        prev_states: dict
            'lstm_state': LSTM previous state
            'r': float32[B, R, W]: previously read content
            'w_r': float32[B, R, N]: previously read position
            'w_w': float32[B, T, N]: previously written position
            'u': float32[B, N]: previous usage vector
            'M': float32[B, N, W]: previous memory bank
            'p': float32[B, T, N]: previous precedence vector
            'L': float32[B, T, N, N]: previous temporal memory linkage
        return:
            y: float32[B, Y] prediction
            next_states: dict with the same keys as prev_states
        """
        B = x.shape[0]  # batch_size
        lstm_state = prev_states['lstm_state']
        r = prev_states['r']  # [B, R, W]
        w_r = prev_states['w_r']  # [B, R, N]
        w_w = prev_states['w_w']  # [B, T, N]
        u = prev_states['u']  # [B, N]
        M = prev_states['M']  # [B, N, W]
        p = prev_states['p']  # [B, T, N]
        L = prev_states['L']  # [B, T, N, N]

        # controller update
        h, c = self.lstm_cell(torch.cat([x, r.view(B, -1)], 1), lstm_state)
        h = h.clamp(min=-self.clip_value, max=self.clip_value)
        c = c.clamp(min=-self.clip_value, max=self.clip_value)

        # controller outputs and interface vector
        nu_and_xi = self.controller_output(h)
        nu, e, v, f, k_w, beta_w, g_a, g_w, pi, k_r, beta_r = nu_and_xi.split(
            [self.output_size,
             self.num_write_heads * self.memory_slot_size,
             self.num_write_heads * self.memory_slot_size,
             self.num_read_heads,
             self.num_write_heads * self.memory_slot_size,
             self.num_write_heads,
             self.num_write_heads,
             self.num_write_heads,
             (2 * self.num_write_heads + 1) * self.num_read_heads,
             self.num_read_heads * self.memory_slot_size,
             self.num_read_heads], 1)
        # The erase vector must lie in [0, 1]. Unbounded, ``1 - w * e`` below
        # could go negative or exceed 1, flipping or inflating memory.
        e = torch.sigmoid(e.view(B, self.num_write_heads, self.memory_slot_size))
        v = v.view(B, self.num_write_heads, self.memory_slot_size)
        f = torch.sigmoid(f)
        k_w = k_w.view(B, self.num_write_heads, self.memory_slot_size)
        beta_w = oneplus(beta_w)
        g_a = torch.sigmoid(g_a)[..., None]
        g_w = torch.sigmoid(g_w)[..., None]
        pi = torch.softmax(pi.view(B, self.num_read_heads, -1), -1)
        k_r = k_r.view(B, self.num_read_heads, self.memory_slot_size)
        beta_r = oneplus(beta_r)
        # shapes:
        # nu: [B, Y]        e: [B, T, W]       v: [B, T, W]
        # f: [B, R]         k_w: [B, T, W]     beta_w: [B, T]
        # g_a: [B, T, 1]    g_w: [B, T, 1]     pi: [B, R, 2 * T + 1]
        # k_r: [B, R, W]    beta_r: [B, R]

        # memory usage: keep slots the free gates did not release, and add
        # the slots written on the previous step
        psi_r = (1 - f[..., None] * w_r).prod(1)  # [B, N]
        psi_w = u + (1 - u) * (1 - (1 - w_w).prod(1))  # [B, N]
        u = psi_r * psi_w  # [B, N]

        # memory allocation, one write head at a time
        a = self._allocation_weights(u, g_w * g_a)  # [B, T, N]

        # memory write
        c_w = attention(M, k_w, beta_w)  # [B, T, N]
        w_w = g_w * (g_a * a + (1 - g_a) * c_w)  # [B, T, N]
        # erasing from several write heads composes multiplicatively
        M = M * (1 - w_w[:, :, :, None] * e[:, :, None, :]).prod(1) + w_w.transpose(1, 2) @ v

        # temporal memory linkage: L[.., i, j] gains w[i] * p[j], i.e. slot i
        # written right after slot j; self-links are zeroed
        L = (1 - w_w[:, :, :, None] - w_w[:, :, None, :]) * L + w_w[:, :, :, None] * p[:, :, None, :]
        Ldiag = torch.arange(self.num_memory_slots, device=L.device)
        L[:, :, Ldiag, Ldiag] = 0
        p = (1 - w_w.sum(2, keepdim=True)) * p + w_w

        # memory read: mix content lookup with backward (w @ L) and forward
        # (w @ L^T) link traversal, one pair per write head
        c_r = attention(M, k_r, beta_r)  # [B, R, N]
        w_r_ex = w_r[:, None, :, :].expand(-1, self.num_write_heads, -1, -1)
        b_r = (w_r_ex @ L).transpose(1, 2)  # [B, R, T, N]
        f_r = (w_r_ex @ L.transpose(2, 3)).transpose(1, 2)  # [B, R, T, N]
        pi_c, pi_b, pi_f = pi.split([1, self.num_write_heads, self.num_write_heads], 2)
        w_r = pi_c * c_r + torch.einsum('ijk,ijkl->ijl', pi_b, b_r) + \
            torch.einsum('ijk,ijkl->ijl', pi_f, f_r)
        r = w_r @ M  # [B, R, W]

        # prediction
        y = nu + self.reader_output(r.view(B, -1))
        y = y.clamp(min=-self.clip_value, max=self.clip_value)
        return y, {
            'lstm_state': (h, c),
            'r': r,
            'w_r': w_r,
            'w_w': w_w,
            'u': u,
            'M': M,
            'p': p,
            'L': L}
class RepeatedCopy(object):
    """Endless stream of "repeat the string" batches.

    Each sample is a random string of length_min..length_max tokens drawn
    from range(n_tokens). The input is the string followed by the
    terminator token ``n_tokens``. The target is the string repeated
    ``n_repeats`` times, also followed by the terminator. Every sequence is
    right-padded to ``length_max * n_repeats + 1`` steps.

    Each iteration yields ``(x, x_enc, y, y_mask)``:
        x      int64[B, S]                 input token ids (0 in padding)
        x_enc  float32[B, S, n_tokens + 1] one-hot inputs (zero rows in padding)
        y      int64[B, S]                 target token ids (0 in padding)
        y_mask bool[B, S]                  True where y holds a real target
    """

    def __init__(self, n_tokens=10, length_min=5, length_max=10, n_repeats=4, batch_size=4):
        self.n_tokens = n_tokens
        self.length_min = length_min
        self.length_max = length_max
        self.n_repeats = n_repeats
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            steps = self.length_max * self.n_repeats + 1
            terminator = torch.tensor([self.n_tokens], dtype=torch.int64)
            x = torch.zeros(self.batch_size, steps, dtype=torch.int64)
            x_enc = torch.zeros(self.batch_size, steps, self.n_tokens + 1)
            y = torch.zeros(self.batch_size, steps, dtype=torch.int64)
            y_mask = torch.zeros(self.batch_size, steps, dtype=torch.bool)
            for row in range(self.batch_size):
                # Same RNG draws, in the same order: first the length, then the tokens.
                length = int(torch.randint(self.length_min, self.length_max + 1, (1,)))
                string = torch.randint(0, self.n_tokens, (length,))
                source = torch.cat([string, terminator])
                target = torch.cat([string.repeat(self.n_repeats), terminator])
                n_in, n_out = source.numel(), target.numel()
                x[row, :n_in] = source
                x_enc[row, torch.arange(n_in), source] = 1
                y[row, :n_out] = target
                y_mask[row, :n_out] = True
            yield x, x_enc, y, y_mask
def train_model(cell, task, N_ITERS, device=None):
    """Train ``cell`` with Adam on ``N_ITERS`` batches drawn from ``task``.

    cell: recurrent cell with ``init(batch_size, device)`` and
        ``cell(x_t, state) -> (logits_t, state)``; logits have
        ``task.n_tokens + 1`` classes.
    task: iterable of ``(x, x_enc, y, y_mask)`` batches exposing ``n_tokens``.
    N_ITERS: number of optimisation steps.
    device: where each batch is moved. Defaults to the device of the cell's
        first parameter, so the function no longer depends on a module-level
        ``dev`` global.
    """
    if device is None:
        device = next(cell.parameters()).device
    opt = torch.optim.Adam(cell.parameters())
    n_classes = task.n_tokens + 1
    with tqdm.trange(N_ITERS) as tq:
        # zip() consumes its arguments left to right, so it stops on the range
        # and never draws a surplus batch from the endless task iterator.
        for _step, (_tokens, x_enc, y, y_mask) in zip(range(N_ITERS), task):
            x_enc = x_enc.to(device)
            y = y.to(device)
            y_mask = y_mask.to(device)

            # Unroll the cell over time and stack the per-step logits.
            state = cell.init(x_enc.shape[0], x_enc.device)
            step_logits = []
            for t in range(x_enc.shape[1]):
                yt_hat, state = cell(x_enc[:, t], state)
                step_logits.append(yt_hat)
            y_hat = torch.stack(step_logits, 1)  # [B, S, n_classes]

            # Score only the real target positions; padding is masked out.
            mask = y_mask.reshape(-1)
            y_flat = y.reshape(-1)[mask]
            y_hat = y_hat.reshape(-1, n_classes)[mask]
            loss = F.cross_entropy(y_hat, y_flat)
            acc = (y_hat.argmax(1) == y_flat).float().mean()

            opt.zero_grad()
            loss.backward()
            opt.step()
            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % acc.item()}, refresh=False)
            tq.update()
def test_model(cell, task, N_TESTS, device=None):
    """Evaluate ``cell`` on ``N_TESTS`` batches from ``task`` without gradients.

    Returns the token-level accuracy over every masked target position of all
    evaluated batches, as a tensor. With ``N_TESTS == 0`` there is nothing to
    average and ``0 / 0`` raises ZeroDivisionError.

    device: where each batch is moved. Defaults to the device of the cell's
        first parameter, so the function no longer depends on a module-level
        ``dev`` global.
    """
    if device is None:
        device = next(cell.parameters()).device
    elem_acc_sum = 0
    elem_acc_count = 0
    with tqdm.trange(N_TESTS) as tq, torch.no_grad():
        # zip() stops on the range first: no surplus batch is generated.
        for _step, (_tokens, x_enc, y, y_mask) in zip(range(N_TESTS), task):
            x_enc = x_enc.to(device)
            y = y.to(device)
            y_mask = y_mask.to(device)

            state = cell.init(x_enc.shape[0], x_enc.device)
            step_logits = []
            for t in range(x_enc.shape[1]):
                yt_hat, state = cell(x_enc[:, t], state)
                step_logits.append(yt_hat)
            y_hat = torch.stack(step_logits, 1)  # [B, S, n_classes]

            correct = y_hat.argmax(-1) == y
            # Element-wise accuracy: only positions flagged by y_mask count.
            elem_correct = correct.reshape(-1)[y_mask.reshape(-1)].float()
            elem_acc = elem_correct.mean()
            elem_acc_sum += elem_correct.sum()
            elem_acc_count += len(elem_correct)
            tq.set_postfix(
                {'elem_acc': '%.03f' % elem_acc.item()},
                refresh=False)
            tq.update()
    return elem_acc_sum / elem_acc_count
# Device that training / evaluation batches and the model are placed on.
dev = 'cpu'

if __name__ == '__main__':
    # Train on the default 5-10 token strings, then evaluate on longer
    # 12-16 token strings with larger batches.
    # Guarded so importing this module does not start a 3500-step training run.
    task = RepeatedCopy()
    test_task = RepeatedCopy(length_min=12, length_max=16, batch_size=20)
    cell = DNCCell(task.n_tokens + 1, task.n_tokens + 1).to(dev)
    train_model(cell, task, 3500)
    print(test_model(cell, test_task, 500))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment