Convolution with masking support.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter
from torch.nn.modules.utils import _pair

# Binarizer, Ternarizer, and DEFAULT_THRESHOLD are assumed to be defined
# elsewhere in this project (see the sketch after the class).


class ElementWiseConv2d(nn.Module):
    """Modified conv that learns an elementwise mask over fixed weights.

    Do we need a mask for the biases too?
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 mask_init='1s', mask_scale=1e-2,
                 threshold_fn='binarizer', threshold=None):
        super(ElementWiseConv2d, self).__init__()
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        self.threshold_fn = threshold_fn
        self.mask_scale = mask_scale
        self.mask_init = mask_init

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = False
        self.output_padding = _pair(0)
        self.groups = groups

        # weight and bias are no longer Parameters: they stay frozen while
        # only the real-valued mask below is trained.
        self.weight = Variable(torch.Tensor(
            out_channels, in_channels // groups, *kernel_size),
            requires_grad=False)
        if bias:
            self.bias = Variable(torch.Tensor(
                out_channels), requires_grad=False)
        else:
            self.register_parameter('bias', None)

        # Initialize a real-valued mask with the same shape as the weights.
        self.mask_real = self.weight.data.new(self.weight.size())
        if mask_init == '1s':
            self.mask_real.fill_(mask_scale)
        elif mask_init == 'uniform':
            self.mask_real.uniform_(-1 * mask_scale, mask_scale)
        # mask_real is the only trainable Parameter of this module.
        self.mask_real = Parameter(self.mask_real)

        # Initialize the thresholder that turns mask_real into a discrete
        # mask on each forward pass (overwrites the string stored above).
        if threshold_fn == 'binarizer':
            if threshold is None:
                threshold = DEFAULT_THRESHOLD
            print('Calling binarizer with threshold:', threshold)
            self.threshold_fn = Binarizer(threshold=threshold)
        elif threshold_fn == 'ternarizer':
            if threshold is None:
                threshold = DEFAULT_THRESHOLD
            print('Calling ternarizer with threshold:', threshold)
            self.threshold_fn = Ternarizer(threshold=threshold)

    def forward(self, input):
        # Threshold the real-valued mask, apply it elementwise to the frozen
        # weights, and run a standard convolution with the masked weights.
        mask_thresholded = self.threshold_fn(self.mask_real)
        weight_thresholded = mask_thresholded * self.weight
        return F.conv2d(input, weight_thresholded, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def __repr__(self):
        s = ('{name} ({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)
        for param in self._parameters.values():
            if param is not None:
                # Variables stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param._grad is not None:
                    param._grad.data = fn(param._grad.data)
        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)
        # weight and bias are plain Variables rather than Parameters, so the
        # loop above misses them; move their data explicitly.
        self.weight.data = fn(self.weight.data)
        if self.bias is not None and self.bias.data is not None:
            self.bias.data = fn(self.bias.data)
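The Binarizer and Ternarizer thresholders (and DEFAULT_THRESHOLD) are referenced above but not included in this file. A minimal sketch of a binarizer, assuming it is a straight-through estimator (hard 0/1 threshold in the forward pass, identity gradient in the backward pass) and written against the modern torch.autograd.Function API rather than the 0.3-era instance-style API this gist targets; the threshold default is a placeholder, not the gist's actual constant:

import torch

DEFAULT_THRESHOLD = 5e-3  # placeholder value; the gist defines its own elsewhere


class BinarizerFn(torch.autograd.Function):
    """Straight-through binarizer: hard 0/1 threshold going forward,
    unmodified gradient going backward."""

    @staticmethod
    def forward(ctx, inputs, threshold):
        outputs = inputs.clone()
        outputs[inputs.le(threshold)] = 0
        outputs[inputs.gt(threshold)] = 1
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient straight through to mask_real; the non-tensor
        # threshold argument gets no gradient.
        return grad_output, None


class Binarizer(torch.nn.Module):
    """Callable wrapper so ElementWiseConv2d can do Binarizer(threshold=t)(mask)."""

    def __init__(self, threshold=DEFAULT_THRESHOLD):
        super(Binarizer, self).__init__()
        self.threshold = threshold

    def forward(self, inputs):
        return BinarizerFn.apply(inputs, self.threshold)

With those pieces in place, a quick smoke test (hypothetical values; the weights are uninitialized at construction and would normally be copied in from a pretrained layer):

conv = ElementWiseConv2d(16, 32, kernel_size=3, padding=1,
                         mask_init='1s', mask_scale=1e-2,
                         threshold_fn='binarizer', threshold=5e-3)
conv.weight.data.normal_(0, 0.02)  # stand-in for pretrained weights
out = conv(Variable(torch.randn(1, 16, 8, 8)))
print(out.size())  # (1, 32, 8, 8)
# Only mask_real is trainable: 32 * 16 * 3 * 3 = 4608 elements.
print(sum(p.numel() for p in conv.parameters()))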