Last active
January 16, 2019 04:00
-
-
Save LiamHz/150c84da6c234ec4aac16202fec4c5ef to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Neural style transfer: optimize the pixels of `target` so its features match
# the content image (at conv4_2) and its Gram matrices match the style image.
#
# NOTE(review): this snippet relies on names defined elsewhere (not visible
# here): `target`, `vgg`, `get_features`, `gram_matrix`, `content_features`,
# `style_grams`, `style_weights` (iterated and indexed by layer name, so a
# mapping), `content_weight`, `style_weight`, `torch` and `optim`.
# `target` presumably needs requires_grad=True for Adam to update it — confirm.

show_every = 400  # Report progress every `show_every` steps
optimizer = optim.Adam([target], lr=0.003)  # Only the target image is optimized
steps = 2000  # Number of optimization steps applied to the target image

# Training loop
for ii in range(steps):
    # Content loss: mean squared error between the target's and the content
    # image's conv4_2 feature maps.
    target_features = get_features(target, vgg)
    content_loss = torch.mean(
        (target_features['conv4_2'] - content_features['conv4_2']) ** 2
    )

    # Style loss: for each style layer, the weighted mean squared difference
    # between the target's Gram matrix and the precomputed style Gram matrix,
    # normalized by the layer's d * h * w.
    style_loss = 0
    for layer, layer_weight in style_weights.items():
        target_feature = target_features[layer]
        _, d, h, w = target_feature.shape  # rank-4 feature map; first dim unused
        target_gram = gram_matrix(target_feature)  # reuse the lookup above
        style_gram = style_grams[layer]  # precomputed style representation
        layer_style_loss = layer_weight * torch.mean((target_gram - style_gram) ** 2)
        style_loss += layer_style_loss / (d * h * w)

    # Weighted combination of the two objectives.
    total_loss = (content_loss * content_weight) + (style_loss * style_weight)

    # Gradient step on the target image.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # `show_every` was previously defined but never used. Report progress at
    # that interval. Displaying the target image itself would need a plotting
    # helper that is not visible in this snippet.
    if ii % show_every == 0:
        print('Total loss: ', total_loss.item())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment