LiamHz/002b30aa7e3fb34d50ae2c22436cce47
Created January 8, 2019 22:04
import torch

def gram_matrix(tensor):
    # Get the batch_size, depth, height, and width of the Tensor.
    # The reshape below only works for a batch size of 1.
    _, d, h, w = tensor.size()
    # Reshape it, so we're multiplying the features for each channel
    tensor = tensor.view(d, h * w)
    # Calculate the gram matrix by multiplying the tensor with its transpose
    gram = torch.mm(tensor, tensor.t())
    return gram
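The sketch below is a usage example. It assumes PyTorch is installed and that the input is an (N, C, H, W) feature map with N = 1, as `gram_matrix` requires. It checks two properties of the result: entry (i, j) is the dot product of flattened channels i and j, and the matrix is symmetric. `batched_gram_matrix` is a hypothetical extension that is not part of the original gist. It uses `torch.bmm` to compute one Gram matrix per image when the batch size is larger than 1.

```python
import torch


def gram_matrix(tensor):
    # Copy of the function above: expects a (1, d, h, w) tensor.
    _, d, h, w = tensor.size()
    features = tensor.view(d, h * w)
    return torch.mm(features, features.t())


def batched_gram_matrix(tensor):
    # Hypothetical batched variant (not in the original gist):
    # returns a (b, d, d) tensor, one Gram matrix per image.
    b, d, h, w = tensor.size()
    features = tensor.reshape(b, d, h * w)
    return torch.bmm(features, features.transpose(1, 2))


# A random "feature map": batch 1, 3 channels, 4x5 spatial size.
torch.manual_seed(0)
x = torch.rand(1, 3, 4, 5)
g = gram_matrix(x)
print(g.shape)  # torch.Size([3, 3])

# Entry (i, j) is the dot product of flattened channels i and j.
flat = x[0].reshape(3, -1)
print(torch.allclose(g[0, 1], (flat[0] * flat[1]).sum()))  # True
print(torch.allclose(g, g.t()))  # True: Gram matrices are symmetric

# For each image, the batched variant matches the single-image version.
xb = torch.rand(2, 3, 4, 5)
gb = batched_gram_matrix(xb)
print(gb.shape)  # torch.Size([2, 3, 3])
print(torch.allclose(gb[1], gram_matrix(xb[1:2])))  # True
```

`reshape` is used in the batched helper because it also accepts non-contiguous inputs. `view` in the original raises an error on those inputs.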