Skip to content

Instantly share code, notes, and snippets.

@neelriyer
Created August 10, 2020 00:22
Show Gist options
  • Save neelriyer/a3fed30a92317fd1e3ef38f87530cc25 to your computer and use it in GitHub Desktop.
Save neelriyer/a3fed30a92317fd1e3ef38f87530cc25 to your computer and use it in GitHub Desktop.
Content loss for neural style transfer in PyTorch
import torch
import torch.nn as nn
# adapted from: https://ghamrouni.github.io/stn-tuto/advanced/neural_style_tutorial.html#
class ContentLoss(nn.Module):
    """Pass-through layer that records a weighted MSE content loss.

    When placed inside a feature-extraction network, it returns its input
    unchanged from ``forward`` and, as a side effect, stores
    ``MSE(input * weight, target * weight)`` on ``self.loss``. The caller can
    then backpropagate that loss via ``backward``.
    """

    def __init__(self, target, weight):
        """Store the detached, pre-weighted content target.

        Args:
            target: Tensor of features to match. It is detached from the
                autograd graph so it is treated as a constant value and no
                gradient is computed for it.
            weight: Multiplier applied to both the input and the target
                before the MSE. NOTE(review): presumably a Python scalar; if
                a tensor is passed it must broadcast against ``target``.
        """
        super().__init__()
        # Registered as a non-persistent buffer rather than a plain attribute
        # so that ``.to(device)`` / ``.cuda()`` / ``.double()`` on the module
        # (or on a model containing it) moves and casts the target together
        # with everything else. ``persistent=False`` keeps it out of
        # ``state_dict()``, which therefore stays empty as before.
        self.register_buffer("target", target.detach() * weight, persistent=False)
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):  # ``input`` shadows the builtin; name kept for callers
        """Compute and store the content loss, then return ``input`` unchanged.

        Args:
            input: Tensor compared against the stored target; scaled by
                ``self.weight`` before the MSE.

        Returns:
            The same ``input`` object (also stored on ``self.output``), so
            the module is transparent to the surrounding network.
        """
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        """Backpropagate the most recently computed loss and return it.

        This is an ordinary method invoked manually by the caller; it is not
        an autograd hook. ``forward`` must have run first, since that is
        where ``self.loss`` is assigned.

        Args:
            retain_graph: Forwarded to ``Tensor.backward``. Defaults to
                ``True`` so the autograd graph is kept after this call and
                can be backpropagated through again. NOTE(review): presumably
                because several loss modules share one graph; confirm with
                callers.

        Returns:
            The scalar loss tensor that was backpropagated.
        """
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment