Skip to content

Instantly share code, notes, and snippets.

@athas
Created July 14, 2020 13:41
Show Gist options
  • Select an option

  • Save athas/64c0bbb1f412deb10312b677e801c5d7 to your computer and use it in GitHub Desktop.

Select an option

Save athas/64c0bbb1f412deb10312b677e801c5d7 to your computer and use it in GitHub Desktop.
-- | Dot product of two equal-length f32 vectors: the sum of their
-- element-wise products.
let dotprod xs ys =
  map2 (\x y -> x * y) xs ys |> f32.sum
-- | Pairwise Euclidean distances between the rows of `A` and the rows of `B`.
-- Element `[i][j]` of the result is `||A[i] - B[j]||`, computed with the
-- expansion `||a||^2 + ||b||^2 - 2 a.b` so the cross terms reuse `dotprod`.
--
-- `A` and `B` may have different row counts; with equal counts the result
-- is `[n][n]`, as before.
let all_pairs_norm [m][n][k] (A: [m][k]f32) (B: [n][k]f32) : [m][n]f32 =
  -- Squared norm of every row, computed once per matrix.
  let sqrA = map (\v -> f32.sum (map (**2) v)) A
  let sqrB = map (\v -> f32.sum (map (**2) v)) B
  in map2 (\a na ->
             -- Row i: `na` is ||A[i]||^2 and must stay fixed along the row,
             -- while `nb` (||B[j]||^2) varies along the columns.
             map2 (\b nb ->
                     -- Floating-point cancellation can push the expansion
                     -- slightly below zero (notably when a == b), which
                     -- would make sqrt return NaN; clamp at 0.
                     f32.sqrt (f32.max 0 (na + nb - 2 * dotprod a b)))
                  B sqrB)
          A sqrA
-- | Entry point: for every pair of rows (i, j) of `A`, reports whether the
-- distance computed by `all_pairs_norm` is at most 4.5.
entry whatever_jax_does A =
  let dists = all_pairs_norm A A
  in map (\row -> map (\d -> d <= 4.5) row) dists
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment