Skip to content

Instantly share code, notes, and snippets.

@janpfeifer
janpfeifer / ops_util.go
Last active February 19, 2025 06:41
Safe cosine similarity
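// CosineSimilarity computes the cosine similarity between lhs (left-hand side)
// and rhs (right-hand side) along their last axis.
//
// Rows that are entirely zero, for which cosine similarity is not normally
// defined, are detected with a mask so they can be handled safely.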
func CosineSimilarity(lhs, rhs *Node) *Node {
g := lhs.Graph()
dtype := lhs.DType()
axis := -1 // Axis over which to calculate the cosine.
// Mask for rows that are fully zero, for which cosine similary is not normally defined.
lhsMask := ReduceAndKeep(IsZero(lhs), ReduceLogicalAnd, axis),