Created
          August 10, 2020 00:22 
        
      - 
      
 - 
        
Save neelriyer/a3fed30a92317fd1e3ef38f87530cc25 to your computer and use it in GitHub Desktop.  
    Content loss module for neural style transfer in PyTorch
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
| import torch.nn as nn | |
| # adapted from: https://ghamrouni.github.io/stn-tuto/advanced/neural_style_tutorial.html# | |
class ContentLoss(nn.Module):
    """Transparent layer that records a weighted content loss.

    The module returns its input unchanged from ``forward``. As a side
    effect, it stores in ``self.loss`` the mean squared error between the
    weighted input and the weighted target. Because both sides are
    multiplied by ``weight``, the stored loss equals
    ``weight ** 2 * MSE(input, target)``.

    Args:
        target: Reference tensor to compare against. It is detached from
            the autograd graph, so it is a fixed value: no gradient flows
            into it.
        weight: Scalar multiplier applied to both the input and the target
            before the MSE is taken.
    """

    def __init__(self, target, weight):
        super().__init__()
        # Detach the target so it is a constant rather than part of the graph
        # that produced it.
        # Registering it as a buffer (instead of a plain attribute) means
        # Module.to()/.cuda()/.half() move and convert it together with the
        # module. persistent=False keeps it out of state_dict(), matching the
        # original attribute's behaviour there.
        self.register_buffer("target", target.detach() * weight, persistent=False)
        self.weight = weight
        self.criterion = nn.MSELoss()
        # Set by forward(); backward() refuses to run until it is populated.
        self.loss = None

    def forward(self, input):
        """Record the weighted MSE against the target and pass *input* through.

        Args:
            input: Tensor compared against ``self.target``. Presumably it has
                the same shape as the target (``MSELoss`` requires
                broadcast-compatible shapes); confirm against the caller.

        Returns:
            *input*, unchanged (also stored as ``self.output``).
        """
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_graph=True):
        """Backpropagate the loss computed by the most recent ``forward`` call.

        Args:
            retain_graph: Forwarded to ``Tensor.backward``. It defaults to
                True, so the graph is kept after this call. Presumably this
                lets other losses backpropagate through the same forward
                pass; confirm against the caller.

        Returns:
            The loss tensor that was backpropagated.

        Raises:
            RuntimeError: If called before ``forward`` has produced a loss.
        """
        if self.loss is None:
            raise RuntimeError("ContentLoss.backward() called before forward()")
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment