
@LiamHz
Created January 8, 2019 22:04
import torch


def gram_matrix(tensor):
    # Get the batch_size, depth (channels), height, and width of the tensor.
    # The batch dimension is discarded, so this assumes a batch size of 1.
    _, d, h, w = tensor.size()
    # Reshape so each row holds the flattened features of one channel
    tensor = tensor.view(d, h * w)
    # Calculate the Gram matrix by multiplying the tensor with its transpose
    gram = torch.mm(tensor, tensor.t())
    return gram
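To make explicit what `torch.mm(tensor, tensor.t())` produces, here is a minimal pure-Python sketch of the same computation for one feature map whose batch dimension has already been dropped. The function name `gram_matrix_reference` and the toy 2-channel, 2 x 2 input are hypothetical and chosen only for illustration. Entry (i, j) of the result is the dot product of flattened channel i with flattened channel j, so the matrix is d x d and symmetric.

```python
from typing import List

Matrix = List[List[float]]


def gram_matrix_reference(feature_map: List[List[List[float]]]) -> Matrix:
    """Gram matrix of a single d x h x w feature map (hypothetical helper).

    Mirrors gram_matrix above: each channel is flattened to a length h*w
    vector, and entry (i, j) is the dot product of channel i with channel j.
    """
    # Flatten each h x w channel into one row of length h*w
    flat = [[v for row in channel for v in row] for channel in feature_map]
    d = len(flat)
    # Row-times-row dot products, i.e. F @ F.T for F of shape (d, h*w)
    return [
        [sum(a * b for a, b in zip(flat[i], flat[j])) for j in range(d)]
        for i in range(d)
    ]


# Two channels, each 2 x 2
fmap = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 1.0], [0.0, 1.0]],
]
print(gram_matrix_reference(fmap))  # [[30.0, 6.0], [6.0, 2.0]]
```

The diagonal entries are the squared norms of each channel, and the off-diagonal entries measure how strongly two channels activate together, which is why the result does not depend on where in the image those activations occur.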