Created
January 28, 2021 01:55
get batch style transfer loss
def get_batch_style_transfer_loss(encoder, decoder, style_image, content_image):
    # create_stylized_image and the compute_* helpers are defined outside this snippet.
    # Encode both inputs with the same encoder.
    style_features = encoder(style_image)
    content_features = encoder(content_image)

    # Produce the stylized image along with the stylized features it was built from.
    stylized_image, stylized_features = create_stylized_image(decoder, content_features, style_features)

    # Re-encode the output so both losses are measured in feature space.
    features_of_stylized = encoder(stylized_image)

    style_loss = compute_style_loss(features_of_stylized, style_features)
    content_loss = compute_content_loss(features_of_stylized, stylized_features)
    return style_loss + content_loss
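
The gist does not show `compute_style_loss` or `compute_content_loss`, so the following is only a rough sketch of what such helpers might compute. It assumes an AdaIN-style setup, where the style loss matches per-channel mean and standard deviation at each encoder layer and the content loss is a mean squared error against the stylized feature target. The names `style_loss_sketch`, `content_loss_sketch`, `channel_stats`, and `mse` are hypothetical, and plain nested lists stand in for tensors to keep the example dependency-free; real training code would use tensor operations instead.

```python
from statistics import fmean, pstdev


def mse(xs, ys):
    """Mean squared error between two equal-length sequences of floats."""
    return fmean([(x - y) ** 2 for x, y in zip(xs, ys)])


def channel_stats(feature_map):
    """Per-channel (mean, std). A feature map is a list of channels,
    each channel a flat list of floats (an assumed stand-in for a tensor)."""
    return [(fmean(channel), pstdev(channel)) for channel in feature_map]


def style_loss_sketch(features_of_stylized, style_features):
    # Assumption: each argument is a list of feature maps, one per encoder layer.
    loss = 0.0
    for stylized_map, style_map in zip(features_of_stylized, style_features):
        stylized_stats = channel_stats(stylized_map)
        style_stats = channel_stats(style_map)
        # Match channel means, then channel standard deviations.
        loss += mse([m for m, _ in stylized_stats], [m for m, _ in style_stats])
        loss += mse([s for _, s in stylized_stats], [s for _, s in style_stats])
    return loss


def content_loss_sketch(features_of_stylized, stylized_features):
    # Assumption: only the deepest layer is compared against the target features.
    deepest = features_of_stylized[-1]
    flat_output = [v for channel in deepest for v in channel]
    flat_target = [v for channel in stylized_features for v in channel]
    return mse(flat_output, flat_target)
```

Under these assumptions both terms are zero when the re-encoded output reproduces the target statistics exactly, which is why the function above can simply return their sum; a weighting factor between the two terms is a common variation but is not part of the original snippet.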