Skip to content

Instantly share code, notes, and snippets.

@crcrpar
Last active May 8, 2019 03:44
Show Gist options
  • Select an option

  • Save crcrpar/65be313f9b8528730939768666dc7e9f to your computer and use it in GitHub Desktop.

Select an option

Save crcrpar/65be313f9b8528730939768666dc7e9f to your computer and use it in GitHub Desktop.
Attention-augmented convolution in Chainer.
import typing
import chainer
from chainer import functions
from chainer import links
from chainer import types
class AttentionAugmentedConvolution2D(chainer.Chain):
    """2D convolution augmented with multi-head self-attention.

    A regular convolution produces ``out_channels - v_dim`` feature maps
    while a parallel multi-head self-attention branch produces ``v_dim``
    feature maps; the two are concatenated along the channel axis
    (Bello et al., "Attention Augmented Convolutional Networks", 2019).

    Args:
        in_channels: Number of input channels.
        out_channels: Total number of output channels (conv + attention).
        ksize: Convolution kernel size.
        k_dim: Total query/key depth summed over all heads.
        v_dim: Total value depth summed over all heads.
        n_head: Number of attention heads; must divide both ``k_dim``
            and ``v_dim``.
        is_relative: If True, add 2D relative-position logits to the
            attention logits.
        stride: Convolution stride (shared by both branches so the two
            outputs stay spatially aligned).
        pad: Convolution padding (shared by both branches).
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            ksize: int,
            k_dim: int,
            v_dim: int,
            n_head: int,
            is_relative: bool = False,
            stride: int = 1,
            pad: int = 0,
    ) -> None:
        super(AttentionAugmentedConvolution2D, self).__init__()
        assert n_head > 0
        # Per-head depths must be integral.
        assert k_dim % n_head == 0
        assert v_dim % n_head == 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ksize = ksize
        self.k_dim = k_dim
        self.v_dim = v_dim
        self.n_head = n_head
        self.is_relative = is_relative
        self.stride = stride
        self.pad = pad
        with self.init_scope():
            # Plain convolution branch: supplies the channels not
            # produced by the attention branch.
            self.conv_out = links.Convolution2D(
                self.in_channels, self.out_channels - self.v_dim, self.ksize,
                stride=self.stride, pad=self.pad)
            # One convolution emits queries, keys and values together;
            # they are split along the channel axis afterwards.
            self.qkv_conv = links.Convolution2D(
                self.in_channels, 2 * self.k_dim + self.v_dim, self.ksize,
                stride=self.stride, pad=self.pad)
            # 1x1 projection applied to the combined attention output.
            self.attn_out = links.Convolution2D(
                self.v_dim, self.v_dim, ksize=1, stride=1)

    def forward(self, x: typing.Union[chainer.Variable, types.NdArray]) -> chainer.Variable:
        """Apply the layer to a (batch, channel, height, width) input.

        Returns:
            Variable of shape ``(batch, out_channels, H', W')`` where the
            spatial size follows from ``ksize``/``stride``/``pad``.
        """
        assert x.ndim == 4
        conv_out = self.conv_out(x)
        batch_size, _, height, width = conv_out.shape
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x)
        # (b, nh, H*W, head) @ (b, nh, head, H*W) -> (b, nh, H*W, H*W)
        logits = functions.matmul(functions.transpose(flat_q, (0, 1, 3, 2)), flat_k)
        if self.is_relative:
            h_rel_logits, w_rel_logits = self.relative_logits(q)
            logits += (h_rel_logits + w_rel_logits)
        weights = functions.softmax(logits, axis=-1)
        attn_out = functions.matmul(weights, functions.transpose(flat_v, (0, 1, 3, 2)))
        attn_out = functions.reshape(attn_out, (batch_size, self.n_head, self.v_dim // self.n_head, height, width))
        attn_out = self.combine_heads_2d(attn_out)
        attn_out = self.attn_out(attn_out)
        return functions.concat((conv_out, attn_out), axis=1)

    def compute_flat_qkv(self, x: chainer.Variable) -> typing.Tuple[chainer.Variable, chainer.Variable, chainer.Variable, chainer.Variable, chainer.Variable, chainer.Variable]:
        """Project the input to queries/keys/values, split heads, flatten.

        Returns ``(flat_q, flat_k, flat_v, q, k, v)`` where the ``flat_*``
        tensors are ``(b, nh, depth_per_head, H*W)`` and ``q``/``k``/``v``
        keep the spatial axes ``(b, nh, depth_per_head, H, W)``.
        """
        batch_size = len(x)
        qkv = self.qkv_conv(x)
        # Channel layout is [q (k_dim) | k (k_dim) | v (v_dim)].
        q, k, v = functions.split_axis(
            qkv, (self.k_dim, 2 * self.k_dim), axis=1)
        q = self.split_heads_2d(q)
        k = self.split_heads_2d(k)
        v = self.split_heads_2d(v)
        head_size = self.k_dim // self.n_head
        # Scaled dot-product attention: scale queries by 1/sqrt(d_k).
        q *= head_size ** -0.5
        flat_q = functions.reshape(q, (batch_size, self.n_head, head_size, -1))
        flat_k = functions.reshape(k, (batch_size, self.n_head, head_size, -1))
        # BUG FIX: the original reshaped ``k`` here, so the attention
        # values were silently replaced by the keys; flatten ``v``.
        flat_v = functions.reshape(v, (batch_size, self.n_head, self.v_dim // self.n_head, -1))
        return flat_q, flat_k, flat_v, q, k, v

    def split_heads_2d(self, x: chainer.Variable) -> chainer.Variable:
        """Split the channel axis into (n_head, channels // n_head)."""
        batch, channels, h, w = x.shape
        ret_shape = (batch, self.n_head, channels // self.n_head, h, w)
        return functions.reshape(x, ret_shape)

    def combine_heads_2d(self, x: chainer.Variable) -> chainer.Variable:
        """Merge the (n_head, depth) axes back into one channel axis."""
        b, nh, size, h, w = x.shape
        ret_shape = (b, nh * size, h, w)
        return functions.reshape(x, ret_shape)

    def relative_logits(self, q: chainer.Variable) -> typing.Tuple[chainer.Variable, chainer.Variable]:
        """Compute relative-position logits along height and width.

        NOTE(review): the relative keys are drawn from ``randn`` on every
        call, so they are noise rather than learned embeddings. In the
        paper they are trainable parameters — consider registering them
        with ``chainer.Parameter``; kept as-is to preserve behavior.
        """
        b, nh, size, h, w = q.shape
        # Move the depth axis last: (b, nh, h, w, size).
        q = functions.transpose(q, (0, 1, 3, 4, 2))
        key_rel_w = self.xp.random.randn(2 * w - 1, size).astype(q.dtype)
        rel_logits_w = self.relative_logits_1d(q, key_rel_w, h, w, nh, 'w')
        # BUG FIX: cast to q.dtype as done for key_rel_w above; without
        # it the einsum mixes float64 keys with float32 queries.
        key_rel_h = self.xp.random.randn(2 * h - 1, size).astype(q.dtype)
        rel_logits_h = self.relative_logits_1d(functions.transpose(q, (0, 1, 3, 2, 4)), key_rel_h, w, h, nh, 'h')
        return rel_logits_h, rel_logits_w

    def relative_logits_1d(self, q, rel_k, h, w, nh, case):
        """Relative logits along one spatial axis, shared over the other.

        ``q`` is (b, nh, h, w, size); the result is (b, nh, h*w, h*w)
        ready to be added to the content logits.
        """
        # (b, nh, h, w, size) x (2w-1, size) -> (b, nh, h, w, 2w-1)
        rel_logits = functions.einsum('bhxyd,md->bhxym', q, rel_k)
        rel_logits = functions.reshape(rel_logits, (-1, nh * h, w, 2 * w - 1))
        rel_logits = self.rel2abs(rel_logits)
        rel_logits = functions.reshape(rel_logits, (-1, nh, h, w, w))
        rel_logits = functions.expand_dims(rel_logits, axis=3)
        # BUG FIX: the original used functions.repeat with a per-axis
        # reps tuple, which Chainer's repeat does not support (that is
        # tile semantics); broadcast the logits over the other axis.
        rel_logits = functions.tile(rel_logits, (1, 1, 1, h, 1, 1))
        if case == 'w':
            rel_logits = functions.transpose(rel_logits, (0, 1, 2, 4, 3, 5))
        else:
            rel_logits = functions.transpose(rel_logits, (0, 1, 4, 5, 2, 3))
        rel_logits = functions.reshape(rel_logits, (-1, nh, h * w, h * w))
        return rel_logits

    def rel2abs(self, x: chainer.Variable) -> chainer.Variable:
        """Convert relative-position logits to absolute indexing.

        Standard pad-reshape-slice trick: pad (b, nh, l, 2l-1) to
        (b, nh, l, 2l), flatten, pad again, and reshape so that slicing
        ``[:, :, :l, l-1:]`` yields absolute (b, nh, l, l) logits.
        """
        b, nh, l, _ = x.shape
        col_pad = self.xp.zeros((b, nh, l, 1), x.dtype)
        x = functions.concat((x, col_pad), axis=3)
        flat_x = functions.reshape(x, (b, nh, l * 2 * l))
        flat_pad = self.xp.zeros((b, nh, l - 1), x.dtype)
        flat_x_padded = functions.concat((flat_x, flat_pad), axis=2)
        final_x = functions.reshape(flat_x_padded, (b, nh, l + 1, 2 * l - 1))
        final_x = final_x[:, :, :l, l - 1:]
        return final_x
if __name__ == '__main__':
    import numpy as np

    # Smoke test: run one forward pass on a random batch and show the
    # resulting output shape.
    print('Check forward computation')
    batch = np.random.randn(16, 3, 32, 32).astype(np.float32)
    layer = AttentionAugmentedConvolution2D(3, 20, 3, 40, 4, 1, True)
    out = layer(batch)
    print(out.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment