Gist by @Steboss89, created October 21, 2022 13:52.
Main function for computing matmul with Strassen in JAX
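For comparison, here is a minimal sketch of one level of Strassen's recursion written out by hand in JAX. It assumes square inputs with an even side length, and `strassen_one_level` is a hypothetical name that is not part of this gist. The sketch computes the seven block products M1..M7 explicitly, whereas the main function below reads the same coefficients from a `factors` table.

```python
import jax.numpy as jnp


def strassen_one_level(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """One level of Strassen: 7 half-size matmuls instead of 8.

    Assumes x and y are square with the same even side length.
    """
    h = x.shape[0] // 2
    # Split both operands into 2 x 2 grids of blocks.
    a11, a12, a21, a22 = x[:h, :h], x[:h, h:], x[h:, :h], x[h:, h:]
    b11, b12, b21, b22 = y[:h, :h], y[:h, h:], y[h:, :h], y[h:, h:]
    # Strassen's seven block products.
    m1 = (a11 + a22) @ (b11 + b22)
    m2 = (a21 + a22) @ b11
    m3 = a11 @ (b12 - b22)
    m4 = a22 @ (b21 - b11)
    m5 = (a11 + a12) @ b22
    m6 = (a21 - a11) @ (b11 + b12)
    m7 = (a12 - a22) @ (b21 + b22)
    # Recombine the products into the four output blocks.
    c11 = m1 + m4 - m5 + m7
    c12 = m3 + m5
    c21 = m2 + m4
    c22 = m1 - m2 + m3 + m6
    return jnp.block([[c11, c12], [c21, c22]])


x = jnp.arange(16.0).reshape(4, 4)
y = jnp.arange(16.0, 32.0).reshape(4, 4)
print(jnp.allclose(strassen_one_level(x, y), x @ y))
```

Applying the same split recursively to each of the seven products gives Strassen's O(n^log2(7)) ≈ O(n^2.81) cost.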
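```python
from typing import List, Sequence

import jax.numpy as jnp
import numpy as np

# A block matrix is an n x n nested list of equally sized JAX arrays (the blocks).
BlockMatrix = List[List[jnp.ndarray]]


def f(a: BlockMatrix, b: BlockMatrix, factors: Sequence[np.ndarray]) -> BlockMatrix:
    """Multiplies block matrices `a` and `b`.

    `factors` is a triple (u, v, w) of arrays of shape (n, n, rank) encoding a
    rank-`rank` decomposition of the n x n matrix multiplication tensor
    (for example, Strassen's rank-7 decomposition for n = 2).
    """
    n = len(a)
    rank = factors[0].shape[-1]
    result = [[None] * n for _ in range(n)]
    for alpha in range(rank):
        # Left operand: sum over (i, j) of u[i, j, alpha] * a[i][j].
        left = None
        for i in range(n):
            for j in range(n):
                if factors[0][i, j, alpha] != 0:
                    curr = factors[0][i, j, alpha] * a[i][j]
                    if left is None:
                        left = curr
                    else:
                        left += curr
        # Right operand: sum over (j, k) of v[j, k, alpha] * b[j][k].
        right = None
        for j in range(n):
            for k in range(n):
                if factors[1][j, k, alpha] != 0:
                    curr = factors[1][j, k, alpha] * b[j][k]
                    if right is None:
                        right = curr
                    else:
                        right += curr
        # Exactly one block multiplication per rank-one term (7 for Strassen, not 8).
        matrix_product = left @ right
        # Add the product into each output block (i, k), weighted by w[i, k, alpha].
        for i in range(n):
            for k in range(n):
                if factors[2][i, k, alpha] != 0:
                    curr = factors[2][i, k, alpha] * matrix_product
                    if result[i][k] is None:
                        result[i][k] = curr
                    else:
                        result[i][k] += curr
    return result
```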
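The function reads `factors[0][i, j, alpha]` as the weight of block `a[i][j]`, `factors[1][j, k, alpha]` as the weight of `b[j][k]`, and `factors[2][i, k, alpha]` as the weight of product `alpha` in output block `(i, k)`. The sketch below builds Strassen's factors in that layout and checks that they decompose the 2 x 2 matrix multiplication tensor. The coefficient values are a transcription of the textbook Strassen formulas under this indexing convention, not data taken from this gist. The names `U`, `V`, `W`, `strassen_factors`, `matmul_tensor` and `reconstruct` are hypothetical.

```python
import numpy as np

# One 2x2 coefficient matrix per Strassen product M1..M7 (alpha = 0..6).
U = [[[1, 0], [0, 1]], [[0, 0], [1, 1]], [[1, 0], [0, 0]], [[0, 0], [0, 1]],
     [[1, 1], [0, 0]], [[-1, 0], [1, 0]], [[0, 1], [0, -1]]]
V = [[[1, 0], [0, 1]], [[1, 0], [0, 0]], [[0, 1], [0, -1]], [[-1, 0], [1, 0]],
     [[0, 0], [0, 1]], [[1, 1], [0, 0]], [[0, 0], [1, 1]]]
W = [[[1, 0], [0, 1]], [[0, 0], [1, -1]], [[0, 1], [0, 1]], [[1, 0], [1, 0]],
     [[-1, 1], [0, 0]], [[0, 0], [0, 1]], [[1, 0], [0, 0]]]

# Stack into arrays of shape (2, 2, 7), indexed [row, col, alpha].
strassen_factors = tuple(np.stack(m, axis=-1).astype(np.float32) for m in (U, V, W))


def matmul_tensor(n: int) -> np.ndarray:
    """T[i, j, j2, k, i2, k2] = 1 iff a[i][j] * b[j2][k] belongs in C[i2][k2]."""
    t = np.zeros((n,) * 6, dtype=np.float32)
    for i in range(n):
        for j in range(n):
            for k in range(n):
                t[i, j, j, k, i, k] = 1
    return t


def reconstruct(factors) -> np.ndarray:
    """Sum over alpha of the outer products u[..., alpha] x v[..., alpha] x w[..., alpha]."""
    u, v, w = factors
    return np.einsum("abr,cdr,efr->abcdef", u, v, w)


print(np.array_equal(reconstruct(strassen_factors), matmul_tensor(2)))
```

If the check passes, `strassen_factors` can be supplied wherever the function reads `factors`, with `n = 2` and `rank = 7`.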