
@vadimkantorov
Created January 18, 2017 16:04
Pairwise L2 distances in Torch
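require 'torch'

-- A is an n x d matrix, B is an m x d matrix (data points are rows);
-- returns the n x m matrix of pairwise Euclidean (L2) distances
function pdist(A, B)
  local eps = 1e-6
  -- n x m matrix of inner products <A[i], B[j]>
  local AB = A * B:t()
  -- squared row norms: A's as an n x 1 column, B's as a 1 x m row, both broadcast to n x m
  local normA = torch.sum(torch.cmul(A, A), 2):view(AB:size(1), 1):expandAs(AB)
  local normB = torch.sum(torch.cmul(B, B), 2):view(1, AB:size(2)):expandAs(AB)
  -- ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2; clamp rounding-induced negatives to 0,
  -- and add eps so sqrt stays away from 0, where its derivative is infinite
  local sqdist = (normA - 2.0 * AB + normB):clamp(0, math.huge)
  return torch.sqrt(sqdist + eps)
end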
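The function avoids an explicit double loop by expanding the squared distance between a row of A and a row of B. The three terms correspond to normA, the matrix product `A * B:t()`, and normB. The sketch below writes that identity out; the symbols $a_i$, $b_j$ (rows of A and B as column vectors), $r_A$, $r_B$ (vectors of squared row norms) and $\mathbf{1}_k$ (a length-$k$ vector of ones) are notation introduced here for illustration, not names from the original code.

```latex
\begin{aligned}
\|a_i - b_j\|_2^2 &= (a_i - b_j)^\top (a_i - b_j) = \|a_i\|_2^2 - 2\,a_i^\top b_j + \|b_j\|_2^2, \\
D_{ij} &= \sqrt{\bigl[\, r_A \mathbf{1}_m^\top - 2\,A B^\top + \mathbf{1}_n r_B^\top \,\bigr]_{ij} + \varepsilon}, \\
r_A &= \bigl(\|a_1\|_2^2, \dots, \|a_n\|_2^2\bigr)^\top, \qquad
r_B = \bigl(\|b_1\|_2^2, \dots, \|b_m\|_2^2\bigr)^\top .
\end{aligned}
```

Computing distances this way replaces n·m separate d-dimensional subtractions with a single matrix multiply. The trade-off is floating-point cancellation when two points nearly coincide, which can push the bracketed term slightly below zero before the square root is taken.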