Computing the Variance of Gradients for Linear Layers
import torch
from torch.autograd import Variable


def linear_with_sumsq(inp, weight, bias=None):
    """Linear layer that, during backward, accumulates the batch mean of the
    squared per-sample gradients on weight.grad_sumsq (and bias.grad_sumsq)."""
    def provide_sumsq(inp, w, b):
        def _h(i):
            # i is the incoming gradient w.r.t. the layer output, shape (batch, out_features).
            # The per-sample weight gradient is the outer product of i[n] and inp[n],
            # so its square factorizes into squared factors. Multiplying by i.size(0)
            # compensates for the 1/batch factor that a batch-mean loss puts into i.
            if not hasattr(w, 'grad_sumsq'):
                w.grad_sumsq = 0
            w.grad_sumsq += ((i**2).t().matmul(inp**2)) * i.size(0)
            if b is not None:
                if not hasattr(b, 'grad_sumsq'):
                    b.grad_sumsq = 0
                b.grad_sumsq += (i**2).sum(0) * i.size(0)
        return _h
    res = inp.matmul(weight.t())
    if bias is not None:
        res = res + bias
    res.register_hook(provide_sumsq(inp, weight, bias))
    return res


weight = Variable(torch.randn(3, 2), requires_grad=True)
inp = Variable(torch.randn(4, 2))
bias = Variable(torch.randn(3), requires_grad=True)

c = linear_with_sumsq(inp, weight, bias)
d = (c**2).sum(1).mean(0)
d.backward()

# manual variance calculation: per-sample gradients, one sample at a time
gr = []
gr_b = []
for i in range(len(inp)):
    w_i = Variable(weight.data, requires_grad=True)
    b_i = Variable(bias.data, requires_grad=True)
    i_i = inp[i:i + 1]
    c_i = i_i.matmul(w_i.t()) + b_i
    d_i = (c_i**2).sum()
    d_i.backward()
    gr.append(w_i.grad.data)
    gr_b.append(b_i.grad.data)
gr = torch.stack(gr, dim=0)
gr_b = torch.stack(gr_b, dim=0)

# each pair should match: Var[g] = E[g^2] - E[g]^2 over the batch
print(gr.var(0, unbiased=False), weight.grad_sumsq - weight.grad**2,
      gr_b.var(0, unbiased=False), bias.grad_sumsq - bias.grad**2)
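Why the hook gives the variance: write $B$ for the batch size (here 4), $x_n$ for the $n$-th input row and $g_n$ for the gradient of the $n$-th un-averaged per-sample loss with respect to the layer output. The per-sample weight gradient is the outer product
$$\frac{\partial d_n}{\partial W} = g_n x_n^\top,$$
so its elementwise square factorizes, and the batch mean of the squared per-sample gradients is a single product of elementwise-squared matrices,
$$\frac{1}{B}\sum_n \left(g_n x_n^\top\right)^{\circ 2} = \frac{1}{B}\,(G \circ G)^\top (X \circ X),$$
where $G$ and $X$ stack the $g_n$ and $x_n$ as rows and $\circ$ denotes the elementwise (Hadamard) product. This is the `(i**2).t().matmul(inp**2)` term in the hook: the incoming gradient `i` already carries one factor $1/B$ per row from the batch mean in the loss, and multiplying by `i.size(0)` leaves exactly the single $1/B$ needed for the mean. The printed `weight.grad_sumsq - weight.grad**2` is then the elementwise $E[g^2] - E[g]^2$, the same per-sample gradient variance that the explicit loop computes.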
The convolution is defined by
$$f_{cij} = \sum_{dkl} w_{cdkl}\, \mathrm{inp}_{d,i+k,j+l},$$
so the derivative is
$$\frac{\partial f_{cij}}{\partial w_{cdkl}} = \mathrm{inp}_{d,i+k,j+l},$$
and the total derivative of some loss $\mathrm{out}$ is
$$
\frac{\partial\, \mathrm{out}}{\partial w_{cdkl}}
= \sum_{ij} \frac{\partial f_{cij}}{\partial w_{cdkl}} \cdot \frac{\partial\, \mathrm{out}}{\partial f_{cij}}
= \sum_{ij} \mathrm{inp}_{d,i+k,j+l} \cdot \frac{\partial\, \mathrm{out}}{\partial f_{cij}}.
$$
Performing the sum over $ij$ first and only then squaring, before summing over the batch, does not seem to be possible in one fused operation right now, so the linear-layer trick above does not carry over directly to convolutions.
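For illustration, a minimal sketch of the explicit per-sample route for convolutions, assuming a newer PyTorch than the gist's (it uses `torch.nn.functional.unfold` and plain tensors instead of `Variable`s); the sizes and the names `per_sample_grad` / `grad_var` are made up. It does the sum over $ij$ separately per sample and squares afterwards, but only by materializing all per-sample gradients:

import torch
import torch.nn.functional as F

# Made-up sizes for illustration; stride 1, no padding, no dilation assumed.
B, C_in, C_out, K, H, W = 4, 3, 5, 3, 8, 8
inp = torch.randn(B, C_in, H, W)
weight = torch.randn(C_out, C_in, K, K, requires_grad=True)

out = F.conv2d(inp, weight)                    # (B, C_out, H-K+1, W-K+1)
loss = (out**2).sum(dim=(1, 2, 3)).mean()      # batch-mean loss, as in the gist
grad_out, = torch.autograd.grad(loss, out)     # d loss / d out, carries a 1/B factor

inp_unf = F.unfold(inp, K)                     # (B, C_in*K*K, L), L = number of output positions
g_unf = grad_out.flatten(2)                    # (B, C_out, L)
# sum over the spatial positions ij first, separately for every sample ...
per_sample_grad = torch.einsum('bcl,bdl->bcd', g_unf, inp_unf)   # (B, C_out, C_in*K*K)
# ... then take the variance over the batch
grad_var = per_sample_grad.var(0, unbiased=False).view(C_out, C_in, K, K)

# check against an explicit per-sample loop (same 1/B convention as grad_out)
check = torch.stack([
    torch.autograd.grad((F.conv2d(inp[b:b + 1], weight)**2).sum() / B, weight)[0]
    for b in range(B)
])
print(torch.allclose(grad_var, check.var(0, unbiased=False), atol=1e-5))

In the linear case the batch index itself is the contraction index of the weight-gradient matmul, so each term of the sum is a single product and squaring the terms just squares the factors; in the convolution each term already contains a sum over $ij$, which is why the squares cannot be pushed inside and the per-sample gradients have to be held explicitly.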