Skip to content

Instantly share code, notes, and snippets.

@kyoto-cheng
Last active September 20, 2021 16:05
Show Gist options
  • Save kyoto-cheng/f0b6f39a8c9e27ca2e06d2956666d06d to your computer and use it in GitHub Desktop.
Save kyoto-cheng/f0b6f39a8c9e27ca2e06d2956666d06d to your computer and use it in GitHub Desktop.
# ---- Style-transfer hyper-parameters -------------------------------------

# Per-layer weights applied to each style (Gram-matrix) loss term.
# weighting earlier layers more will result in *larger* style artifacts
# `conv4_2` is intentionally not listed: it is used as the content
# representation, not as a style layer.
style_weights = dict(
    conv1_1=1.0,
    conv2_1=0.75,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)

# Relative importance of the content term (alpha) and style term (beta)
# in the total loss.
content_weight = 1
style_weight = 1e6

# Print progress / save a snapshot every `show_every` iterations.
show_every = 1000

# Number of optimisation iterations.
# NOTE(review): the original comment mentioned "(5000)" — presumably an
# earlier/alternative setting; confirm which value is intended.
steps = 20_000

# Adam updates the pixels of `target` directly.
optimizer = optim.Adam([target], lr=5e-3)
# Optimisation loop: repeatedly update the pixels of `target` so that its
# features on `conv4_2` match the content image's, and its Gram matrices on
# the layers in `style_weights` match the precomputed `style_grams`.
for ii in range(1, steps + 1):
    # get the features from the current target image
    target_features = get_features(target, vgg)

    # content loss: mean squared error on the content layer `conv4_2`
    content_loss = torch.mean(
        (target_features['conv4_2'] - content_features['conv4_2']) ** 2
    )

    # style loss: start at 0 and add each layer's weighted Gram-matrix loss
    style_loss = 0
    for layer in style_weights:
        # the "target" style representation for this layer
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        # feature map is unpacked as 4-D; d, h, w give the normalising size
        _, d, h, w = target_feature.shape
        # the "style" style representation (defined outside this chunk)
        style_gram = style_grams[layer]
        # the style loss for one layer, weighted appropriately
        layer_style_loss = style_weights[layer] * torch.mean(
            (target_gram - style_gram) ** 2
        )
        # normalise by feature-map size so layers of different sizes
        # contribute on a comparable scale
        style_loss += layer_style_loss / (d * h * w)

    # total loss = alpha * content + beta * style
    total_loss = content_weight * content_loss + style_weight * style_loss

    # update the target image
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # periodically print the loss and save a snapshot of the target image
    if ii % show_every == 0:
        print('Total loss: ', total_loss.item())
        snapshot = plt.figure(figsize=(8, 12))
        # draw the current target; without this the saved PNG is blank
        plt.imshow(im_convert(target))
        snapshot.savefig("Graph" + str(ii) + ".png", format="PNG")
        # close (not merely clear) the figure: pyplot keeps every open
        # figure alive, so clf() alone accumulates one figure per snapshot
        plt.close(snapshot)
# Show the content image (left) and the final stylised target (right).
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
for _panel, _image in ((ax1, content), (ax2, target)):
    _panel.imshow(im_convert(_image))
# NOTE(review): no plt.show()/savefig here, so outside a notebook or
# interactive backend this figure may never be rendered — confirm intent.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment