Last active
May 8, 2019 03:44
-
-
Save crcrpar/65be313f9b8528730939768666dc7e9f to your computer and use it in GitHub Desktop.
attention augmented convolution in Chainer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import typing | |
| import chainer | |
| from chainer import functions | |
| from chainer import links | |
| from chainer import types | |
class AttentionAugmentedConvolution2D(chainer.Chain):

    """Attention augmented convolution (Bello et al., 2019).

    Concatenates the output of a plain convolution with multi-head
    2-D self-attention features along the channel axis, so the layer
    emits ``out_channels`` channels in total: ``out_channels - v_dim``
    from the convolution and ``v_dim`` from attention.

    Args:
        in_channels: Number of input channels.
        out_channels: Total number of output channels.
        ksize: Kernel size of both internal convolutions.
        k_dim: Total query/key dimensionality summed over all heads.
        v_dim: Total value dimensionality summed over all heads.
        n_head: Number of attention heads; must divide ``k_dim`` and
            ``v_dim``.
        is_relative: If ``True``, add 2-D relative-position logits.
        stride: Stride of both internal convolutions.
        pad: Padding of both internal convolutions.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            ksize: int,
            k_dim: int,
            v_dim: int,
            n_head: int,
            is_relative: bool = False,
            stride: int = 1,
            pad: int = 0,
    ) -> None:
        super(AttentionAugmentedConvolution2D, self).__init__()
        assert n_head > 0
        assert k_dim % n_head == 0
        assert v_dim % n_head == 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ksize = ksize
        self.k_dim = k_dim
        self.v_dim = v_dim
        self.n_head = n_head
        self.is_relative = is_relative
        self.stride = stride
        self.pad = pad
        with self.init_scope():
            # Regular convolution branch: emits the channels not
            # produced by the attention branch.
            self.conv_out = links.Convolution2D(
                self.in_channels, self.out_channels - self.v_dim, self.ksize,
                stride=self.stride, pad=self.pad)
            # Joint projection to queries (k_dim), keys (k_dim) and
            # values (v_dim), split later along the channel axis.
            self.qkv_conv = links.Convolution2D(
                self.in_channels, 2 * self.k_dim + self.v_dim, self.ksize,
                stride=self.stride, pad=self.pad)
            # 1x1 output projection of the combined attention heads.
            self.attn_out = links.Convolution2D(
                self.v_dim, self.v_dim, ksize=1, stride=1)

    def forward(self, x: typing.Union[chainer.Variable, types.NdArray]) -> chainer.Variable:
        """Apply the layer to a ``(batch, channel, height, width)`` input.

        Returns:
            A ``(batch, out_channels, H', W')`` variable, the channel-wise
            concatenation of the convolution and attention branches.
        """
        assert x.ndim == 4
        conv_out = self.conv_out(x)
        batch_size, _, height, width = conv_out.shape
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x)
        # (B, Nh, HW, HW) scaled dot-product logits: q^T k per head.
        logits = functions.matmul(
            functions.transpose(flat_q, (0, 1, 3, 2)), flat_k)
        if self.is_relative:
            h_rel_logits, w_rel_logits = self.relative_logits(q)
            logits += (h_rel_logits + w_rel_logits)
        # Normalize over the key positions (last axis of the 4-D logits).
        weights = functions.softmax(logits, axis=3)
        # (B, Nh, HW, dv/Nh): weighted sum of values per query position.
        attn_out = functions.matmul(
            weights, functions.transpose(flat_v, (0, 1, 3, 2)))
        # BUGFIX: move the head-channel axis in front of the spatial axis
        # before reshaping; reshaping (B, Nh, HW, dv/Nh) directly to
        # (B, Nh, dv/Nh, H, W) would scramble channels and positions.
        attn_out = functions.transpose(attn_out, (0, 1, 3, 2))
        attn_out = functions.reshape(
            attn_out,
            (batch_size, self.n_head, self.v_dim // self.n_head,
             height, width))
        attn_out = self.combine_heads_2d(attn_out)
        attn_out = self.attn_out(attn_out)
        return functions.concat((conv_out, attn_out), axis=1)

    def compute_flat_qkv(self, x: chainer.Variable) -> typing.Tuple[chainer.Variable, chainer.Variable, chainer.Variable, chainer.Variable, chainer.Variable, chainer.Variable]:
        """Project ``x`` to per-head queries/keys/values.

        Returns both the flattened ``(B, Nh, d, HW)`` tensors used for
        the attention matmuls and the unflattened ``(B, Nh, d, H, W)``
        tensors (``q`` is needed later for relative logits).
        """
        batch_size = len(x)
        qkv = self.qkv_conv(x)
        q, k, v = functions.split_axis(
            qkv, (self.k_dim, 2 * self.k_dim), axis=1)
        q = self.split_heads_2d(q)
        k = self.split_heads_2d(k)
        v = self.split_heads_2d(v)
        head_size = self.k_dim // self.n_head
        # Scale queries by 1/sqrt(d_k^h) as in scaled dot-product attention.
        q *= head_size ** -0.5
        flat_q = functions.reshape(q, (batch_size, self.n_head, head_size, -1))
        flat_k = functions.reshape(k, (batch_size, self.n_head, head_size, -1))
        # BUGFIX: flatten v, not k (copy-paste error; k also has the wrong
        # channel count for this reshape).
        flat_v = functions.reshape(
            v, (batch_size, self.n_head, self.v_dim // self.n_head, -1))
        return flat_q, flat_k, flat_v, q, k, v

    def split_heads_2d(self, x: chainer.Variable) -> chainer.Variable:
        """Reshape ``(B, C, H, W)`` to ``(B, Nh, C/Nh, H, W)``."""
        batch, channels, h, w = x.shape
        ret_shape = (batch, self.n_head, channels // self.n_head, h, w)
        return functions.reshape(x, ret_shape)

    def combine_heads_2d(self, x: chainer.Variable) -> chainer.Variable:
        """Reshape ``(B, Nh, d, H, W)`` back to ``(B, Nh*d, H, W)``."""
        b, nh, size, h, w = x.shape
        ret_shape = (b, nh * size, h, w)
        return functions.reshape(x, ret_shape)

    def relative_logits(self, q: chainer.Variable) -> typing.Tuple[chainer.Variable, chainer.Variable]:
        """Compute relative-position logits along height and width.

        NOTE(review): the relative keys are drawn from ``randn`` on every
        call instead of being learned parameters as in the paper, so the
        relative term is random noise per forward pass — TODO: register
        them as ``chainer.Parameter``s once the spatial size is fixed.
        """
        b, nh, size, h, w = q.shape
        # To channels-last (B, Nh, H, W, d) for the einsum below.
        q = functions.transpose(q, (0, 1, 3, 4, 2))
        key_rel_w = self.xp.random.randn(2 * w - 1, size).astype(q.dtype)
        rel_logits_w = self.relative_logits_1d(q, key_rel_w, h, w, nh, 'w')
        # BUGFIX: cast to q's dtype like key_rel_w above; randn returns
        # float64, which would not match float32 activations in einsum.
        key_rel_h = self.xp.random.randn(2 * h - 1, size).astype(q.dtype)
        rel_logits_h = self.relative_logits_1d(
            functions.transpose(q, (0, 1, 3, 2, 4)), key_rel_h, w, h, nh, 'h')
        return rel_logits_h, rel_logits_w

    def relative_logits_1d(self, q, rel_k, h, w, nh, case):
        """Relative logits along one axis; returns ``(B, Nh, HW, HW)``."""
        # (B, Nh, h, w, 2w-1): dot of each query with every relative key.
        rel_logits = functions.einsum('bhxyd,md->bhxym', q, rel_k)
        rel_logits = functions.reshape(rel_logits, (-1, nh * h, w, 2 * w - 1))
        rel_logits = self.rel2abs(rel_logits)
        rel_logits = functions.reshape(rel_logits, (-1, nh, h, w, w))
        rel_logits = functions.expand_dims(rel_logits, axis=3)
        # BUGFIX: functions.repeat does not accept a per-axis reps tuple;
        # tile (as in the reference tf.tile) broadcasts the size-1 axis.
        rel_logits = functions.tile(rel_logits, (1, 1, 1, h, 1, 1))
        if case == 'w':
            # (b, nh, h1, h2, w1, w2) -> (b, nh, h1, w1, h2, w2)
            rel_logits = functions.transpose(rel_logits, (0, 1, 2, 4, 3, 5))
        else:
            # BUGFIX: (b, nh, w1, w2, h1, h2) -> (b, nh, h1, w1, h2, w2)
            # requires mask (0, 1, 4, 2, 5, 3) (as in the reference
            # implementation), not (0, 1, 4, 5, 2, 3).
            rel_logits = functions.transpose(rel_logits, (0, 1, 4, 2, 5, 3))
        rel_logits = functions.reshape(rel_logits, (-1, nh, h * w, h * w))
        return rel_logits

    def rel2abs(self, x: chainer.Variable) -> chainer.Variable:
        """Convert relative indexing ``(.., L, 2L-1)`` to absolute ``(.., L, L)``.

        Standard pad-reshape-slice trick from the relative-attention papers.
        """
        b, nh, l, _ = x.shape
        col_pad = self.xp.zeros((b, nh, l, 1), x.dtype)
        x = functions.concat((x, col_pad), axis=3)
        flat_x = functions.reshape(x, (b, nh, l * 2 * l))
        flat_pad = self.xp.zeros((b, nh, l - 1), x.dtype)
        flat_x_padded = functions.concat((flat_x, flat_pad), axis=2)
        final_x = functions.reshape(flat_x_padded, (b, nh, l + 1, 2 * l - 1))
        final_x = final_x[:, :, :l, l - 1:]
        return final_x
if __name__ == '__main__':
    import numpy as np

    # Smoke test: run one forward pass on a random batch and show the
    # output shape (expected: (16, 20, H', W')).
    print('Check forward computation')
    dummy_input = np.random.randn(16, 3, 32, 32).astype(np.float32)
    layer = AttentionAugmentedConvolution2D(
        3, 20, 3, 40, 4, 1, True)
    result = layer(dummy_input)
    print(result.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment