# Source: GitHub gist martinsotir/b51fc38e85cb728b1c187fc32c789e06 (last active June 7, 2019 21:59).
# GitHub notes that this file may contain bidirectional Unicode text, which can be interpreted or
# compiled differently than it appears; review it in an editor that reveals hidden Unicode characters.
# WIP | |
# Inspired from Keras and https://towardsdatascience.com/how-to-visualize-convolutional-features-in-40-lines-of-code-70b7d87b0030 | |
from pathlib import Path | |
import torch | |
import torchvision.utils as vutils | |
import matplotlib.pyplot as plt | |
from torchvision.utils import make_grid | |
import cv2 | |
import numpy as np | |
from torchvision import transforms | |
from PIL import Image | |
def total_variation_loss(x):  # From: https://discuss.pytorch.org/t/yet-another-post-on-custom-loss-functions/14552
    """Return the squared total-variation penalty of a batch of images.

    The penalty is the sum, over all batch elements and channels, of the
    squared differences between vertically and horizontally adjacent pixels,
    divided by ``H * W``. It is not averaged over the batch or channel axes.

    Args:
        x: 4-D tensor indexed as ``(B, C, H, W)``.

    Returns:
        A 0-dim tensor, differentiable with respect to ``x``.
    """
    _, _, height, width = x.size()
    # Squared differences between neighbouring rows (along H) and columns
    # (along W). Summing them directly equals the L1 norm of these
    # non-negative values, avoids the deprecated ``torch.norm``, and also
    # works when H == 1 or W == 1 (reshaping an empty tensor with -1 raises).
    d_rows = (x[:, :, 1:, :] - x[:, :, :-1, :]).pow(2)
    d_cols = (x[:, :, :, 1:] - x[:, :, :, :-1]).pow(2)
    return (d_rows.sum() + d_cols.sum()) / (height * width)
class SaveFeatures():
    """Forward hook that keeps the most recent output of ``module`` in ``self.features``.

    The output is stored as-is, still attached to the autograd graph. A loss
    built from ``self.features`` can therefore be back-propagated to the
    model's input.

    Can also be used as a context manager; the hook is removed on exit.
    """

    def __init__(self, module):
        self.features = None  # replaced on every forward pass of ``module``
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Keep the graph: copying with ``torch.tensor(output, ...)`` would create
        # a new leaf tensor (no gradient reaches the input) and ``.cuda()``
        # would force a CUDA device regardless of where the model runs.
        self.features = output

    def close(self):
        """Remove the forward hook from the module."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.close()
        return False
def visualize_batch(model, module, nb_filters, lr=0.1, color=True, tv_reg=0.01, l2_reg=0.05, size=56, steps=10,
                    max_pixel_value=1, device=None):
    """Optimise one random image per filter so that it activates that filter.

    Image ``i`` of the batch is optimised with Adam to maximise the mean of
    channel ``i`` of ``module``'s output (plus the l2 term below). A
    total-variation penalty on the whole batch is subtracted.

    Args:
        model: network to run; all of its parameters get ``requires_grad=False``.
        module: sub-module of ``model`` whose output is hooked.
        nb_filters: number of images / filters; must not exceed the channel
            count of ``module``'s output.
        lr: Adam learning rate.
        color: optimise 3-channel images if True, 1-channel images otherwise.
        tv_reg: weight of ``total_variation_loss``.
        l2_reg: weight of the per-image norm term (see NOTE below).
        size: height and width of each image.
        steps: number of optimisation steps.
        max_pixel_value: pixel scale used to centre and normalise the norm term.
        device: device for the images. Defaults to the device of ``model``'s
            first parameter, or CUDA when the model has no parameters.

    Returns:
        Detached tensor of shape ``(nb_filters, 3 if color else 1, size, size)``.
    """
    # Freeze the network: only the input image is optimised.
    for param in model.parameters():
        param.requires_grad_(False)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    activations = SaveFeatures(module)
    try:
        img = torch.rand((nb_filters, 3 if color else 1, size, size), dtype=torch.float32,
                         requires_grad=True, device=device)
        optimizer = torch.optim.Adam([img], lr=lr, weight_decay=0)
        half = max_pixel_value / 2
        for _ in range(steps):
            optimizer.zero_grad()
            model(img)  # the hook stores module's output in activations.features
            # NOTE(review): the l2 term sits inside the negated sum, so it is
            # *rewarded* (pixels are pushed away from max_pixel_value / 2)
            # rather than penalised. If a penalty was intended it should be
            # subtracted instead; confirm before changing.
            loss = -torch.stack([activations.features[i, i].mean()
                                 + l2_reg * torch.norm((img[i] - half) / half, p=2)
                                 for i in range(nb_filters)]).mean() + tv_reg * total_variation_loss(img)
            loss.backward()
            optimizer.step()
    finally:
        # Remove the hook so repeated calls do not accumulate hooks on ``module``.
        activations.close()
    return img.detach()
def plot_filter_ma(net, layer_name, layer_conv_name, lr=0.01, steps=100, color=False, tv_reg=0.01, l2_reg=0.01,
                   size=80, max_pixel_value=1):
    """Show the filters of ``net -> CCN -> layer_name -> layer_conv_name`` in one matplotlib figure.

    The target block is looked up by child name. It must have a ``conv``
    child, whose ``out_channels`` sets how many images ``visualize_batch``
    produces. The images are tiled with ``make_grid``.

    Args:
        net: model with a ``CCN`` child that contains ``layer_name``.
        layer_name: name of the child of ``CCN`` to descend into.
        layer_conv_name: name of the child of ``layer_name`` to visualise.
        color: optimise 3-channel images (displayed with reversed channel
            order) instead of 1-channel ones. The default is ``False``: the
            former default ``color=color`` read a module global assigned only
            after this ``def``, which raised NameError at definition time.
        lr, steps, tv_reg, l2_reg, size, max_pixel_value: forwarded to
            ``visualize_batch``.

    Raises:
        KeyError: if one of the child names does not exist.
    """
    layer_module = net
    for child_name in ('CCN', layer_name, layer_conv_name):
        layer_module = dict(layer_module.named_children())[child_name]
    nb_filters = dict(layer_module.named_children())['conv'].out_channels
    img = visualize_batch(net, layer_module, nb_filters, lr=lr, steps=steps, color=color, tv_reg=tv_reg,
                          l2_reg=l2_reg, size=size, max_pixel_value=max_pixel_value)
    # make_grid returns a CHW tensor; matplotlib wants HWC.
    grid = np.moveaxis(make_grid(img, normalize=True).cpu().numpy(), 0, 2)
    # With colour, reverse the channel order (presumably BGR -> RGB; TODO
    # confirm against the model's input convention). Without colour, show
    # channel 0 of the grid.
    shown = grid[:, :, ::-1] if color else grid[:, :, 0]
    plt.figure(figsize=(15, 15))
    plt.title(f"{layer_name}_{layer_conv_name}")
    plt.imshow(shown)
    plt.show(block=False)
if __name__ == "__main__":
    # Example run. ``net`` is a placeholder: replace ``...`` with a trained
    # model that has a ``CCN`` child containing ``base_layer`` (see plot_filter_ma).
    net = ...
    if net is ...:
        raise SystemExit("Set `net` to a trained model before running this script.")
    net = net.cuda().eval()
    color = False
    plot_filter_ma(net, 'base_layer', '0', lr=0.1, steps=300, color=color, tv_reg=0.0001, l2_reg=0.01, size=80,
                   max_pixel_value=1)
# Discussion of this gist takes place on GitHub (sign in to comment).