
@wcneill
Created July 12, 2020 19:51
style loss
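import torch


def style_loss(s_grams, t_features, weights):
    """
    Compute the style loss: the weighted sum, over the layers in `weights`,
    of the mean squared error between the target and style Gram matrices,
    with each layer's term further divided by its feature-map size d * h * w.
    """
    loss = 0
    # For each weighted layer, build the target's Gram matrix and compare it
    # with the precomputed style Gram matrix for that layer.
    for layer in weights:
        # Channel count and spatial size of this layer's target feature map.
        _, d, h, w = t_features[layer].shape
        # gramian() is defined elsewhere in the original code (not shown here).
        t_gram = gramian(t_features[layer])
        layer_loss = torch.mean((t_gram - s_grams[layer]) ** 2) / (d * h * w)
        loss += layer_loss * weights[layer]
    return loss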
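The function depends on a `gramian` helper that this gist does not include. Below is a minimal self-contained sketch that assumes a common construction: flatten a `(1, d, h, w)` feature map to `(d, h*w)` and take `F @ F.T`, with no extra normalization. That `gramian` is hypothetical, as are the layer names (`conv1_1`, `conv2_1`), tensor shapes, and weights, which are made up for illustration. The loss itself mirrors the function above.

```python
import torch


def gramian(features: torch.Tensor) -> torch.Tensor:
    """Hypothetical Gram-matrix helper: (1, d, h, w) feature map -> (d, d) matrix."""
    _, d, h, w = features.shape
    flat = features.reshape(d, h * w)  # one row per channel
    return flat @ flat.t()             # inner products between every pair of channels


def style_loss(s_grams, t_features, weights):
    """Weighted sum over layers of MSE(target Gram, style Gram) / (d * h * w)."""
    loss = 0.0
    for layer, weight in weights.items():
        _, d, h, w = t_features[layer].shape
        t_gram = gramian(t_features[layer])
        layer_loss = torch.mean((t_gram - s_grams[layer]) ** 2) / (d * h * w)
        loss = loss + weight * layer_loss
    return loss


# Illustrative usage with made-up layer names and random feature maps.
torch.manual_seed(0)
style_feats = {"conv1_1": torch.rand(1, 4, 8, 8), "conv2_1": torch.rand(1, 8, 4, 4)}
style_grams = {k: gramian(v) for k, v in style_feats.items()}
layer_weights = {"conv1_1": 1.0, "conv2_1": 0.5}

# Target features identical to the style features: every Gram difference is zero.
same_feats = {k: v.clone() for k, v in style_feats.items()}
same = style_loss(style_grams, same_feats, layer_weights)

# Perturbed target features that track gradients.
noisy = {
    k: (v + 0.1 * torch.randn_like(v)).detach().requires_grad_(True)
    for k, v in style_feats.items()
}
different = style_loss(style_grams, noisy, layer_weights)
different.backward()  # populates .grad on each tensor in `noisy`
```

Because the loss is assembled from differentiable tensor operations, `backward()` yields gradients for the target features. In this sketch they land on the feature tensors directly; in a full style-transfer pipeline they would flow back through the CNN to the target image.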