# Modification of Chainer's code
# See https://github.com/chainer/chainer/blob/master/LICENSE
import numpy
import chainer
from chainer import backend
from chainer.backends import cuda
from chainer import configuration
from chainer import function_node
from chainer.utils import type_check


if cuda.cudnn_enabled:
    cudnn = cuda.cudnn
    libcudnn = cuda.cuda.cudnn


class GroupNormalization(function_node.FunctionNode):

    mean = None
    inv_std = None
    dummy_gamma = None

    def __init__(self, groups, eps=1e-5):
        if not isinstance(groups, int):
            raise TypeError('Argument: \'groups\' type must be (int).')
        self.groups = groups
        self.eps = eps

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 3)
        x_type, gamma_type, beta_type = in_types
        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim >= 2,
            gamma_type.ndim == 1,
            beta_type.ndim == 1,
            gamma_type.dtype == x_type.dtype,
            beta_type.dtype == x_type.dtype,
            x_type.shape[1] % self.groups == 0,
            x_type.shape[1] == gamma_type.shape[0],
            gamma_type.shape == beta_type.shape,
        )

    def forward(self, inputs):
        xp = backend.get_array_module(*inputs)
        if xp is not numpy and chainer.should_use_cudnn('>=auto', 5000):
            # Delegate to the cuDNN implementation and return its result.
            return self.forward_cudnn(inputs)

        self.retain_inputs((0, 1))
        x, gamma, beta = inputs
        orig_shape = x.shape
        batch_size, channels = orig_shape[:2]
        groups = self.groups

        # Normalize each (sample, group) row by its own mean and variance.
        x = x.reshape((batch_size * groups, -1))
        self.mean = x.mean(axis=1)
        x_hat = x - self.mean[:, None]
        var = (x_hat * x_hat).mean(axis=1)
        var += self.eps
        self.inv_std = var
        del var
        xp.sqrt(self.inv_std, out=self.inv_std)
        xp.reciprocal(self.inv_std, out=self.inv_std)
        x_hat *= self.inv_std[:, None]

        # Apply the per-channel scale and shift.
        y = x_hat.reshape((batch_size, channels, -1))
        y *= gamma[:, None]
        y += beta[:, None]
        y = y.reshape(orig_shape)
        return y,

    def forward_cudnn(self, inputs):
        if self.eps < libcudnn.CUDNN_BN_MIN_EPSILON:
            raise RuntimeError(
                'cuDNN does not allow an eps value '
                'less than {}.'.format(libcudnn.CUDNN_BN_MIN_EPSILON))
        self.retain_inputs((0, 1))
        x, gamma, beta = inputs
        xp = cuda.cupy
        orig_shape = x.shape
        batch_size, channels = orig_shape[:2]
        groups = self.groups

        # Reuse cuDNN batch normalization by treating each (sample, group)
        # pair as one channel of a dummy (1, N * groups, -1, 1) tensor.
        # gamma must be one and beta zero so that cuDNN returns the plain
        # normalized values.
        x = x.reshape((1, batch_size * groups, -1, 1))
        with cuda.get_device_from_array(x):
            self.dummy_gamma = xp.ones(batch_size * groups, dtype=x.dtype)
            dummy_beta = xp.zeros_like(self.dummy_gamma)
            self.mean = xp.empty_like(self.dummy_gamma)
            self.inv_std = xp.empty_like(self.dummy_gamma)
        x_hat = cudnn.batch_normalization_forward_training(
            x, self.dummy_gamma, dummy_beta, dummy_beta, dummy_beta,
            self.mean, self.inv_std, self.eps, 1.0,
            True, libcudnn.CUDNN_BATCHNORM_SPATIAL,
            configuration.config.debug)

        y = x_hat.reshape((batch_size, channels, -1))
        cuda.elementwise(
            'T gamma, T beta', 'T y',
            'y = y * gamma + beta',
            'groupnorm_y')(gamma[:, None], beta[:, None], y)
        y = y.reshape(orig_shape)
        return y,

    def backward(self, indexes, grad_outputs):
        x, gamma = self.get_retained_inputs()
        gy, = grad_outputs
        orig_shape = x.shape
        batch_size = orig_shape[0]
        groups = self.groups

        x = chainer.functions.reshape(x, (batch_size * groups, -1))
        x_hat, = _XHat(
            self.eps, self.mean, self.inv_std,
            self.dummy_gamma).apply((x,))
        gx_hat, ggamma, gbeta = _GradHelper().apply((x_hat, gamma, gy))
        gx, = _XHatGrad(
            self.eps, self.mean, self.inv_std,
            self.dummy_gamma, x_hat.array).apply((x, gx_hat))
        gx = gx.reshape(orig_shape)
        return gx, ggamma, gbeta


class _GradHelper(function_node.FunctionNode):

    def forward(self, inputs):
        self.retain_inputs((0, 1, 2))
        x_hat, gamma, gy = inputs
        xp = backend.get_array_module(x_hat)
        x_hat_shape = x_hat.shape
        batch_size, channels = gy.shape[:2]
        x_hat = x_hat.reshape((batch_size, channels, -1))
        gy = gy.reshape((batch_size, channels, -1))

        gx_hat = gy * gamma[:, None]
        if xp is numpy:
            ggamma = (gy * x_hat).sum(axis=(0, 2))
        else:
            ggamma = cuda.reduce(
                'T gy, T x_hat', 'T ggamma',
                'gy * x_hat', 'a + b', 'ggamma = a', '0',
                'groupnorm_ggamma')(gy, x_hat, axis=(0, 2))
        gbeta = gy.sum(axis=(0, 2))

        gx_hat = gx_hat.reshape(x_hat_shape)
        return gx_hat, ggamma, gbeta

    def backward(self, indexes, grad_outputs):
        x_hat, gamma, gy = self.get_retained_inputs()
        ggx_hat, gggamma, ggbeta = grad_outputs
        x_hat_shape = x_hat.shape
        orig_shape = gy.shape
        batch_size, channels = gy.shape[:2]
        x_hat = x_hat.reshape((batch_size, channels, -1))
        gy = gy.reshape((batch_size, channels, -1))
        ggx_hat = ggx_hat.reshape((batch_size, channels, -1))

        gx_hat2 = gggamma[:, None] * gy
        ggamma2 = chainer.functions.sum(ggx_hat * gy, axis=(0, 2))
        ggy = (ggx_hat * gamma[:, None] + gggamma[:, None] * x_hat +
               ggbeta[:, None])

        gx_hat2 = chainer.functions.reshape(gx_hat2, x_hat_shape)
        ggy = chainer.functions.reshape(ggy, orig_shape)
        return gx_hat2, ggamma2, ggy


class _XHat(function_node.FunctionNode):

    def __init__(self, eps, mean, inv_std, dummy_gamma):
        self.eps = eps
        self.mean = mean
        self.inv_std = inv_std
        self.dummy_gamma = dummy_gamma

    def forward_cpu(self, inputs):
        self.retain_inputs((0,))
        x, = inputs
        x_hat = x - self.mean[:, None]
        x_hat *= self.inv_std[:, None]
        self.retain_outputs((0,))
        return x_hat,

    def forward_gpu(self, inputs):
        self.retain_inputs((0,))
        x, = inputs
        x_hat = cuda.elementwise(
            'T x, T mean, T inv_std', 'T x_hat',
            'x_hat = (x - mean) * inv_std',
            'groupnorm_x_hat')(x, self.mean[:, None], self.inv_std[:, None])
        self.retain_outputs((0,))
        return x_hat,

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()
        x_hat, = self.get_retained_outputs()
        gx_hat, = grad_outputs
        return _XHatGrad(
            self.eps, self.mean, self.inv_std,
            self.dummy_gamma, x_hat.array).apply((x, gx_hat))


class _XHatGrad(function_node.FunctionNode):

    def __init__(self, eps, mean, inv_std, dummy_gamma, x_hat):
        self.eps = eps
        self.mean = mean
        self.inv_std = inv_std
        self.dummy_gamma = dummy_gamma
        self.x_hat = x_hat

    def forward(self, inputs):
        xp = backend.get_array_module(*inputs)
        if xp is not numpy and chainer.should_use_cudnn('>=auto', 5000):
            # Delegate to the cuDNN implementation and return its result.
            return self.forward_cudnn(inputs)

        self.retain_inputs((0, 1))
        _, gx_hat = inputs
        x_hat = self.x_hat
        gx_hat_avg = gx_hat.mean(axis=1, keepdims=True)
        gx_hat_x_hat_avg = (gx_hat * x_hat).mean(axis=1, keepdims=True)
        gx_std = gx_hat - gx_hat_avg - x_hat * gx_hat_x_hat_avg
        gx = self.inv_std[:, None] * gx_std
        self.retain_outputs((0,))
        return gx,

    def forward_cudnn(self, inputs):
        if self.eps < libcudnn.CUDNN_BN_MIN_EPSILON:
            raise RuntimeError(
                'cuDNN does not allow an eps value '
                'less than {}.'.format(libcudnn.CUDNN_BN_MIN_EPSILON))
        self.retain_inputs((0, 1))
        x, gx_hat = inputs
        gx, _, _ = cudnn.batch_normalization_backward(
            x[None, :, :, None], self.dummy_gamma, gx_hat[None, :, :, None],
            self.mean, self.inv_std, self.eps,
            True, libcudnn.CUDNN_BATCHNORM_SPATIAL,
            configuration.config.debug)
        gx = gx.reshape(x.shape)
        self.retain_outputs((0,))
        return gx,

    def backward(self, indexes, grad_outputs):
        F = chainer.functions
        x, gx_hat = self.get_retained_inputs()
        gx, = self.get_retained_outputs()
        ggx, = grad_outputs
        ret = []
        if 0 in indexes:
            # We need differentiable x_hat Variable here.
            x_hat, = _XHat(
                self.eps, self.mean, self.inv_std,
                self.dummy_gamma).apply((x,))
            # gx = inv_std * gx_std
            # dgx = dinv_std * gx_std + inv_std * dgx_std
            # -gx2l = (ggx * dinv_std * gx_std) / dx
            # -gx_hat2r = (inv_std * ggx * dgx_std) / dx_hat
            gx2l_std = x_hat * F.mean(ggx * gx, axis=1, keepdims=True)
            gx2l, = _MulInvStd(
                self.eps, self.mean, self.inv_std,
                self.dummy_gamma).apply((x, gx2l_std))
            gx_hat2r_std = (
                ggx * F.mean(gx_hat * x_hat, axis=1, keepdims=True) +
                gx_hat * F.mean(ggx * x_hat, axis=1, keepdims=True))
            gx_hat2r, = _MulInvStd(
                self.eps, self.mean, self.inv_std,
                self.dummy_gamma).apply((x, gx_hat2r_std))
            gx2r, = _XHatGrad(
                self.eps, self.mean, self.inv_std,
                self.dummy_gamma, self.x_hat).apply((x, gx_hat2r))
            gx2 = -(gx2l + gx2r)
            ret.append(gx2)
        if 1 in indexes:
            ggx_hat, = _XHatGrad(
                self.eps, self.mean, self.inv_std,
                self.dummy_gamma, self.x_hat).apply((x, ggx))
            ret.append(ggx_hat)
        return ret


class _MulInvStd(function_node.FunctionNode):

    def __init__(self, eps, mean, inv_std, dummy_gamma):
        self.eps = eps
        self.mean = mean
        self.inv_std = inv_std
        self.dummy_gamma = dummy_gamma

    def forward(self, inputs):
        self.retain_inputs((0,))
        _, y = inputs
        z = self.inv_std[:, None] * y
        self.retain_outputs((0,))
        return z,

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()
        z, = self.get_retained_outputs()
        gz, = grad_outputs
        x_hat, = _XHat(
            self.eps, self.mean, self.inv_std,
            self.dummy_gamma).apply((x,))
        gx_std = x_hat * chainer.functions.mean(gz * z, axis=1, keepdims=True)
        gx, = _MulInvStd(
            self.eps, self.mean, self.inv_std,
            self.dummy_gamma).apply((x, gx_std))
        gy, = _MulInvStd(
            self.eps, self.mean, self.inv_std,
            self.dummy_gamma).apply((x, gz))
        return gx, gy


def group_normalization(x, groups, gamma, beta, eps=1e-5):
    """Group normalization function.

    This function implements group normalization, which divides the channels
    into groups, computes the mean and variance within each group, normalizes
    by these statistics, and then scales and shifts the result.

    Args:
        x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`): Batch tensors.
            The first dimension of this value must be the size of the
            minibatch and the second dimension must be the number of
            channels. Moreover, this value must have one or more following
            dimensions, such as height and width.
        groups (int):
            The number of channel groups.
            This value must be a divisor of the number of channels.
        gamma (~chainer.Variable): Scaling parameter.
        beta (~chainer.Variable): Shifting parameter.
        eps (float): Epsilon value for numerical stability of normalization.

    Returns:
        ~chainer.Variable: The output variable which has the same shape
        as :math:`x`.

    See: `Group Normalization <https://arxiv.org/abs/1803.08494>`_

    """
    return GroupNormalization(groups, eps).apply((x, gamma, beta))[0]
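

# A minimal usage sketch (not part of the original gist) that assumes this
# file is run directly as a script. The tensor sizes and tolerance values
# below are arbitrary and chosen only for illustration.
if __name__ == '__main__':
    batch_size, channels, height, width = 2, 6, 4, 4
    groups = 3

    x = numpy.random.randn(
        batch_size, channels, height, width).astype(numpy.float32)
    gamma = numpy.ones(channels, dtype=numpy.float32)
    beta = numpy.zeros(channels, dtype=numpy.float32)

    x_var = chainer.Variable(x)
    y = group_normalization(x_var, groups, gamma, beta)

    # Cross-check the forward pass against a direct NumPy computation of the
    # same per-(sample, group) statistics.
    x_g = x.reshape(batch_size * groups, -1)
    mean = x_g.mean(axis=1, keepdims=True)
    var = x_g.var(axis=1, keepdims=True)
    x_hat = (x_g - mean) / numpy.sqrt(var + 1e-5)
    expected = x_hat.reshape(batch_size, channels, -1)
    expected = (expected * gamma[:, None] + beta[:, None]).reshape(x.shape)
    numpy.testing.assert_allclose(y.array, expected, rtol=1e-4, atol=1e-5)

    # The function node is differentiable, so gradients flow back as usual.
    y.grad = numpy.ones_like(y.array)
    y.backward()
    print('max |y - expected|:', float(numpy.abs(y.array - expected).max()))
    print('gradient shape w.r.t. x:', x_var.grad.shape)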