import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepInversionFeatureHook():
    '''
    Implementation of the forward hook to track feature statistics and compute a loss on them.
    Will compute mean and variance, and will use l2 as a loss.
    '''
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # hook to compute DeepInversion's feature distribution regularization
        nch = input[0].shape[1]
        mean = input[0].mean([0, 2, 3])
        var = input[0].permute(1, 0, 2, 3).contiguous().view(
            [nch, -1]).var(1, unbiased=False)
        # forcing mean and variance to match between the two distributions
        # other ways might work better, e.g. KL divergence
        r_feature = torch.norm(module.running_var.data.detach() - var, 2) + torch.norm(
            module.running_mean.data.detach() - mean, 2)
        self.r_feature = r_feature
        # the hook must not return anything

    def close(self):
        self.hook.remove()


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)

        # register hooks on every BatchNorm layer
        self.loss_r_feature_layers = []
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                self.loss_r_feature_layers.append(
                    DeepInversionFeatureHook(module))

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        # sum the per-layer regularization terms inside forward, so the loss is
        # computed on each replica and gathered by DataParallel
        loss_r_feature = sum(mod.r_feature for mod in self.loss_r_feature_layers)
        return x, loss_r_feature


net = nn.DataParallel(MyNet().cuda())
# net = MyNet().cuda()

for i in range(10):
    print('========> %d iteration.' % i)
    output, extra_loss = net(torch.randn(512, 3, 32, 32).cuda())
    print('=> extra_loss:', extra_loss.size(), extra_loss.device, extra_loss)
    print('=> extra_loss sum:', extra_loss.sum().size(), extra_loss.sum().device, extra_loss.sum())
    # loss = F.mse_loss(output, torch.ones_like(output))
    # print('=> mse loss:', loss.size(), loss.device, loss)
    print('=> output:', output.size(), output.device, output.grad_fn)
    loss = extra_loss.sum()
    loss.backward()
    net.zero_grad()
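For context, a minimal sketch of how the returned r_feature term would typically drive DeepInversion-style input optimization, assuming the net and hooks defined above; the 0.01 weight and the step count are just illustrative choices:

# sketch: optimize a batch of synthetic inputs so that their BatchNorm feature
# statistics match the model's stored running statistics
inputs = torch.randn(512, 3, 32, 32, device='cuda', requires_grad=True)
optimizer = torch.optim.Adam([inputs], lr=0.1)

net.eval()  # keep the BN running statistics fixed while optimizing the inputs
for p in net.parameters():
    p.requires_grad_(False)  # the model is frozen; only the inputs are optimized

for step in range(100):
    optimizer.zero_grad()
    _, loss_r_feature = net(inputs)
    loss = 0.01 * loss_r_feature.sum()  # only the BN-statistics regularizer in this sketch
    loss.backward()
    optimizer.step()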
XinDongol (author) commented on Aug 3, 2020:

I also tried registering the hooks outside the model, after wrapping it with DataParallel, and summing the stored r_feature values in the training loop:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepInversionFeatureHook():
    '''
    Implementation of the forward hook to track feature statistics and compute a loss on them.
    Will compute mean and variance, and will use l2 as a loss.
    '''
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # hook to compute DeepInversion's feature distribution regularization
        nch = input[0].shape[1]
        mean = input[0].mean([0, 2, 3])
        var = input[0].permute(1, 0, 2, 3).contiguous().view(
            [nch, -1]).var(1, unbiased=False)
        # forcing mean and variance to match between the two distributions
        # other ways might work better, e.g. KL divergence
        r_feature = torch.norm(module.running_var.data.detach() - var, 2) + torch.norm(
            module.running_mean.data.detach() - mean, 2)
        self.r_feature = r_feature
        # the hook must not return anything

    def close(self):
        self.hook.remove()


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        # hooks are registered outside the model here, after wrapping with DataParallel

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x


net = nn.DataParallel(MyNet().cuda())
# net = MyNet().cuda()

# register hooks on every BatchNorm layer of the wrapped model
loss_r_feature_layers = []
for module in net.modules():
    if isinstance(module, nn.BatchNorm2d):
        loss_r_feature_layers.append(
            DeepInversionFeatureHook(module))

for i in range(10):
    print('========> %d iteration.' % i)
    output = net(torch.randn(512, 3, 32, 32).cuda())
    # sum the regularization terms stored on the hook objects after the forward pass
    extra_loss = sum(mod.r_feature for mod in loss_r_feature_layers)
    print('=> extra_loss:', extra_loss.size(), extra_loss.device, extra_loss)
    print('=> extra_loss sum:', extra_loss.sum().size(), extra_loss.sum().device, extra_loss.sum())
    # loss = F.mse_loss(output, torch.ones_like(output))
    # print('=> mse loss:', loss.size(), loss.device, loss)
    print('=> output:', output.size(), output.device, output.grad_fn)
    loss = extra_loss.sum()
    loss.backward()
    net.zero_grad()
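One way to see what this variant actually leaves on the hook objects is to inspect the stored tensors after a forward pass. As far as I can tell, the replicas created by DataParallel end up calling the same hook instances, so each r_feature may hold whichever replica's statistics were written last; a small diagnostic sketch, assuming the setup above:

# diagnostic sketch: check which device each stored r_feature ended up on,
# i.e. which replica's statistics the shared hook object holds after the forward
with torch.no_grad():
    net(torch.randn(512, 3, 32, 32).cuda())
for idx, hook in enumerate(loss_r_feature_layers):
    print('hook %d: r_feature on %s, value %.4f'
          % (idx, hook.r_feature.device, hook.r_feature.item()))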