@J3698
Created July 1, 2021 04:32
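def get_style_transfer_loss(encoder, decoder, content_image, style_image, lambda_content, lambda_style):
    # Content batch must be (g_batch_size, 3, 256, 256); assert_shape and g_batch_size are defined elsewhere in the project.
    assert_shape(content_image, (g_batch_size, 3, 256, 256))
    # The encoder appears to return a list of per-layer feature maps, deepest layer last.
    style_features = encoder(style_image)
    content_features = encoder(content_image)
    # Decode stylized images; also returns the target features that were decoded.
    stylized_images, stylized_features = create_stylized_images(decoder, content_features, style_features)
    # Re-encode the output so both losses are measured in the encoder's feature space.
    features_of_stylized = encoder(stylized_images)
    # The style loss receives every layer; the content loss receives only the deepest one.
    style_loss = compute_style_loss(features_of_stylized, style_features)
    content_loss = compute_content_loss(features_of_stylized[-1], stylized_features)
    # Weighted style loss, weighted content loss, and the images (e.g. for logging).
    return style_loss * lambda_style, content_loss * lambda_content, stylized_images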
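The helpers `create_stylized_images`, `compute_style_loss` and `compute_content_loss` are not shown in this gist. The overall structure matches AdaIN (Huang and Belongie, 2017): content features are re-normalized to the style's per-channel statistics, the style loss compares channel means and standard deviations layer by layer, and the content loss compares the deepest re-encoded layer against the decoder's target features. That correspondence is an assumption. The sketch below is written in NumPy so it runs without a deep-learning framework, whereas the original presumably works on framework tensors. The names `channel_mean_std`, `adain`, `content_loss` and `style_loss`, along with `eps=1e-5`, are hypothetical stand-ins. It also uses mean squared error where the paper writes L2 distances, which is a common substitution.

```python
import numpy as np


def channel_mean_std(x, eps=1e-5):
    # x has shape (N, C, H, W); statistics are taken over the spatial dims per channel.
    # eps is an assumed constant that keeps the std away from zero.
    mean = x.mean(axis=(2, 3), keepdims=True)
    std = np.sqrt(x.var(axis=(2, 3), keepdims=True) + eps)
    return mean, std


def adain(content_feat, style_feat):
    # Hypothetical stand-in for the stylization step: normalize the content
    # features per channel, then rescale/shift them to the style's mean and std.
    c_mean, c_std = channel_mean_std(content_feat)
    s_mean, s_std = channel_mean_std(style_feat)
    return s_std * (content_feat - c_mean) / c_std + s_mean


def content_loss(deepest_features_of_stylized, target):
    # Hypothetical stand-in for compute_content_loss: MSE between the
    # re-encoded output's deepest layer and the AdaIN target features.
    return float(np.mean((deepest_features_of_stylized - target) ** 2))


def style_loss(features_of_stylized, style_features):
    # Hypothetical stand-in for compute_style_loss: for each layer, MSE between
    # channel means plus MSE between channel stds, summed over layers.
    total = 0.0
    for f_out, f_style in zip(features_of_stylized, style_features):
        o_mean, o_std = channel_mean_std(f_out)
        s_mean, s_std = channel_mean_std(f_style)
        total += np.mean((o_mean - s_mean) ** 2) + np.mean((o_std - s_std) ** 2)
    return float(total)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    content = rng.normal(size=(2, 4, 8, 8))
    style = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 8, 8))
    t = adain(content, style)
    t_mean, t_std = channel_mean_std(t)
    s_mean, s_std = channel_mean_std(style)
    # The AdaIN output carries the style's channel statistics.
    print(np.allclose(t_mean, s_mean), np.allclose(t_std, s_std, atol=1e-3))  # → True True
```

In this reading, the content loss targets the re-normalized features rather than the raw content features, which is why the gist passes `stylized_features` (not `content_features`) to `compute_content_loss`. The decoder is then trained to invert the statistics swap instead of reproducing the original content image.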