@J3698
Created January 28, 2021 01:55
get batch style transfer loss
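def get_batch_style_transfer_loss(encoder, decoder, style_image, content_image):
    # Encode both inputs into feature space.
    style_features = encoder(style_image)
    content_features = encoder(content_image)
    # Build the stylized image and keep the stylized features it was built from.
    stylized_image, stylized_features = create_stylized_image(decoder, content_features, style_features)
    # Re-encode the output so both losses are computed in feature space.
    features_of_stylized = encoder(stylized_image)
    style_loss = compute_style_loss(features_of_stylized, style_features)
    # Content loss compares against the stylized features, not the raw content features.
    content_loss = compute_content_loss(features_of_stylized, stylized_features)
    return style_loss + content_loss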
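The function above calls three helpers that the gist does not define: create_stylized_image, compute_style_loss, and compute_content_loss. The sketch below is one possible reading of them, assuming an AdaIN-style setup in which content features are renormalized to the style's per-channel mean and standard deviation. Everything in it is an assumption rather than the author's code: the helper bodies, the adain and channel_stats names, the EPS constant, and the representation of features as a list of channels of plain floats. A real version would most likely use tensors from a deep learning framework and compare statistics at several encoder layers.

```python
from statistics import fmean, pstdev

EPS = 1e-5  # assumed small constant that avoids division by zero


def channel_stats(channel):
    """Return (mean, population std) of one flattened feature channel."""
    return fmean(channel), pstdev(channel)


def adain(content_features, style_features):
    """Rescale each content channel to the matching style channel's mean and std."""
    out = []
    for c_chan, s_chan in zip(content_features, style_features):
        c_mean, c_std = channel_stats(c_chan)
        s_mean, s_std = channel_stats(s_chan)
        out.append([s_std * (x - c_mean) / (c_std + EPS) + s_mean for x in c_chan])
    return out


def create_stylized_image(decoder, content_features, style_features):
    """Hypothetical helper: decode the AdaIN output and also return it as the content target."""
    stylized_features = adain(content_features, style_features)
    return decoder(stylized_features), stylized_features


def compute_style_loss(features_of_stylized, style_features):
    """Squared gap between per-channel means and stds, averaged over channels."""
    total = 0.0
    for a, b in zip(features_of_stylized, style_features):
        a_mean, a_std = channel_stats(a)
        b_mean, b_std = channel_stats(b)
        total += (a_mean - b_mean) ** 2 + (a_std - b_std) ** 2
    return total / len(style_features)


def compute_content_loss(features_of_stylized, stylized_features):
    """Mean squared error between re-encoded features and the AdaIN target."""
    squared_errors = [
        (x - y) ** 2
        for a, b in zip(features_of_stylized, stylized_features)
        for x, y in zip(a, b)
    ]
    return fmean(squared_errors)
```

Under these assumptions the style loss only sees channel statistics, while the content loss compares values element by element, which is why the two can be summed without one trivially cancelling the other.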