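# The training loop below relies on a frozen VGG feature extractor and two helpers
# (get_features and gram_matrix) defined in earlier cells of the notebook this gist
# comes from. A minimal sketch of that setup is included here, following the standard
# Gatys et al. style-transfer recipe; the layer mapping and variable names are
# assumptions, not necessarily the exact code used elsewhere in the notebook.
import torch
import torch.optim as optim
from torchvision import models

vgg = models.vgg19(pretrained=True).features    # Convolutional part of VGG19
for param in vgg.parameters():
    param.requires_grad_(False)                 # Freeze VGG; only the target image is optimized

def get_features(image, model):
    """Run an image through the model and collect activations at the named conv layers."""
    layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1',
              '19': 'conv4_1', '21': 'conv4_2', '28': 'conv5_1'}
    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features

def gram_matrix(tensor):
    """Gram matrix of a feature map: correlations between the flattened channels."""
    _, d, h, w = tensor.size()
    tensor = tensor.view(d, h * w)
    return torch.mm(tensor, tensor.t())

# Also assumed from earlier cells: `target` (a clone of the content image with
# requires_grad=True), `content_features = get_features(content, vgg)`,
# `style_grams` (Gram matrices of the style image's features), per-layer
# `style_weights`, and the scalar `content_weight` / `style_weight` hyperparameters.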
show_every = 400                              # Show the target image every x steps
optimizer = optim.Adam([target], lr=0.003)    # The optimizer updates the target image's pixels directly
steps = 2000                                  # How many iterations to update the target image

# Training loop
for ii in range(steps):
    # Get the target image's features and calculate the content loss
    target_features = get_features(target, vgg)
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

    # Initialize the style loss to 0
    style_loss = 0

    # Iterate through each style layer and add to the style loss
    for layer in style_weights:
        # Get the target image's feature representation for this layer
        target_feature = target_features[layer]
        _, d, h, w = target_feature.shape

        target_gram = gram_matrix(target_feature)    # Gram matrix of the target's features
        style_gram = style_grams[layer]              # Precomputed Gram matrix of the style image

        # Calculate the weighted style loss for this layer, normalized by the layer size
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss / (d * h * w)

    # Calculate the total loss
    total_loss = (content_loss * content_weight) + (style_loss * style_weight)

    # Update the target image
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
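# After training, the optimized `target` tensor can be converted back into a displayable
# image with a small helper (named im_convert here for illustration). This is a hedged
# sketch: it assumes the images were preprocessed with the standard ImageNet
# normalization, as torchvision's pretrained VGG expects; if a different normalization
# was used earlier in the notebook, adjust the statistics accordingly.
import numpy as np

def im_convert(tensor):
    """Detach a (1, C, H, W) image tensor, un-normalize it, and return an H x W x C array."""
    image = tensor.to('cpu').clone().detach().numpy().squeeze()
    image = image.transpose(1, 2, 0)                       # CHW -> HWC
    image = image * np.array((0.229, 0.224, 0.225)) \
            + np.array((0.485, 0.456, 0.406))              # Undo ImageNet normalization
    return image.clip(0, 1)

final_image = im_convert(target)    # e.g. plt.imshow(final_image) to view the result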