Last active
September 20, 2021 16:05
-
-
Save kyoto-cheng/f0b6f39a8c9e27ca2e06d2956666d06d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Per-layer weights for the style loss.
# Weighting the earlier conv layers more heavily produces *larger* style
# artifacts in the output image. Note that `conv4_2` is deliberately
# excluded here: it serves as the content representation instead.
style_weights = {
    'conv1_1': 1.,
    'conv2_1': 0.75,
    'conv3_1': 0.2,
    'conv4_1': 0.2,
    'conv5_1': 0.2,
}

# Relative importance of content vs. style in the total loss.
content_weight = 1    # alpha
style_weight = 1e6    # beta

# How often (in iterations) to snapshot the target image while optimizing.
show_every = 1000
# Optimization hyperparameters: Adam updates the pixels of `target` directly,
# so the "parameter list" handed to the optimizer is just the image tensor.
optimizer = optim.Adam([target], lr=0.005)
steps = 20000  # how many update iterations to run (5000 is also reasonable)
# Main optimization loop: repeatedly nudge the pixels of `target` so that its
# VGG features match the content image at conv4_2 and the style image's gram
# matrices at the style layers.
for ii in range(1, steps + 1):

    # Re-extract the target's feature maps after the previous pixel update.
    target_features = get_features(target, vgg)

    # Content loss: MSE between target and content features at conv4_2.
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

    # Style loss: accumulate the weighted gram-matrix MSE over each style layer.
    style_loss = 0
    for layer in style_weights:
        # Get the "target" style representation for the layer.
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        # Precomputed "style" representation for the same layer.
        style_gram = style_grams[layer]
        # The style loss for one layer, weighted appropriately and
        # normalized by the feature-map size so layers are comparable.
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
        style_loss += layer_style_loss / (d * h * w)

    # Total loss = alpha * content + beta * style.
    total_loss = content_weight * content_loss + style_weight * style_loss

    # Gradient step on the target image's pixels.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Periodically report the loss and save a snapshot of the target image.
    if ii % show_every == 0:
        print('Total loss: ', total_loss.item())
        plt.figure(figsize=(8, 12))
        # BUG FIX: the imshow call was commented out, so savefig was writing
        # blank figures — draw the current target before saving the snapshot.
        plt.imshow(im_convert(target))
        plt.savefig("Graph" + str(ii) + ".png", format="PNG")
        # BUG FIX: close() instead of clf() — clf() only clears the canvas,
        # so a new figure object leaked on every snapshot.
        plt.close()
# Show the original content image and the final stylized target side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
for axis, image in zip((ax1, ax2), (content, target)):
    axis.imshow(im_convert(image))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment