@neelriyer
Created August 10, 2020 00:23
Style loss for style transfer in PyTorch
# Adapted from: https://github.com/alishdipani/Neural-Style-Transfer-Audio/blob/master/NeuralStyleTransfer.py
import torch
import torch.nn as nn


class GramMatrix(nn.Module):
    def forward(self, input):
        # a = batch size (=1), b = number of feature maps,
        # c = length of each (1-D) feature map
        a, b, c = input.size()

        # reshape F_XL into \hat F_XL: one row per feature map
        features = input.view(a * b, c)

        # compute the Gram product
        G = torch.mm(features, features.t())

        # 'normalize' the values of the Gram matrix
        # by dividing by the number of elements in each feature map
        return G.div(a * b * c)


class StyleLoss(nn.Module):
    def __init__(self, target, weight):
        super(StyleLoss, self).__init__()
        # detach the target Gram matrix so it is treated as a fixed constant
        self.target = target.detach() * weight
        self.weight = weight
        self.gram = GramMatrix()
        self.criterion = nn.MSELoss()

    def forward(self, input):
        # pass the input through unchanged; the loss is stored as a side effect
        self.output = input.clone()
        self.G = self.gram(input)
        self.G.mul_(self.weight)
        self.loss = self.criterion(self.G, self.target)
        return self.output

    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
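
A minimal usage sketch of how these modules might be wired into an audio style-transfer loop, in the spirit of the linked Neural-Style-Transfer-Audio script. The feature extractor, the tensor shapes, and the names `style_float`, `input_float`, `style_weight`, and the optimizer settings below are illustrative assumptions, not taken from the gist; in practice the spectrogram tensors would come from an STFT of real audio.

# Usage sketch (assumptions: a random 1-D conv net stands in for the feature
# extractor, and random tensors stand in for spectrograms of shape
# [1, n_freq_bins, n_frames]; names and hyperparameters are illustrative).
import torch
import torch.nn as nn
import torch.optim as optim

cnn = nn.Sequential(
    nn.Conv1d(in_channels=513, out_channels=32, kernel_size=3, padding=1),
    nn.ReLU(),
)

# placeholder spectrograms (replace with real STFT magnitudes)
style_float = torch.randn(1, 513, 400)
input_float = torch.randn(1, 513, 400)

style_weight = 1000
gram = GramMatrix()

# target Gram matrix computed from the style input's features
target = gram(cnn(style_float)).detach()
style_loss = StyleLoss(target, style_weight)

# optimize the input spectrogram directly
input_param = input_float.clone().requires_grad_(True)
optimizer = optim.Adam([input_param], lr=0.03)

for step in range(500):
    optimizer.zero_grad()
    features = cnn(input_param)
    style_loss(features)      # forward pass stores style_loss.loss
    style_loss.backward()     # backpropagate the style loss into input_param
    optimizer.step()

Because StyleLoss returns its input unchanged and only records the loss, it can also be inserted as a layer inside an nn.Sequential feature extractor, which is how the older PyTorch style-transfer tutorial that this pattern follows uses it.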