Last active
July 17, 2020 20:22
-
-
Save danking/5972bf5dba62acab7ef9c2c38af9893b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hail as hl | |
# Sanity check: build a 2x2 ndarray from a flat list and display it.
hl.nd.array([1, 2, 3, 4]).reshape((2, 2)).show()
# FIXME: use ndarray sum / fma
def block_product(left, right):
    """Multiply two ndarrays and flatten the result so it can be summed.

    Args:
        left: left operand of the ``@`` product.
        right: right operand of the ``@`` product.

    Returns:
        A struct with two fields:
          shape -- the shape of ``left @ right``.
          block -- the entries of the product as a flat array, laid out in
                   column-major order (row index varies fastest).
    """
    result = left @ right
    row_count, col_count = result.shape

    def entry_at(flat_index):
        # Invert the column-major layout: flat = col * row_count + row.
        return result[flat_index % row_count, flat_index // row_count]

    return hl.struct(
        shape=result.shape,
        block=hl.range(hl.int(row_count * col_count)).map(entry_at),
    )
def block_aggregate(prod):
    """Sum flattened blocks across records and rebuild a single ndarray.

    Builds aggregator expressions (``hl.agg.array_sum`` / ``hl.agg.take``),
    so it is meant to be used where Hail aggregators are allowed.

    Args:
        prod: a struct with ``shape`` and ``block`` fields, as produced by
            ``block_product``.

    Returns:
        An ndarray built with ``hl.nd.from_column_major`` from the
        element-wise sum of every ``block``.
    """
    summed_entries = hl.agg.array_sum(prod.block)
    # The shape is taken from one arbitrary record.
    # NOTE(review): presumably every record has the same shape -- confirm.
    first_shape = hl.agg.take(prod.shape, 1)[0]
    return hl.nd.from_column_major(summed_entries, first_shape)
# --- Stand-alone ndarrays -------------------------------------------------
x = hl.nd.array([1, 2, 3, 4]).reshape((2, 2))
y = hl.nd.array([1, 0, 0, 1]).reshape((2, 2))  # 2x2 identity
x.collect()
y.collect()
block_product(x, y).collect()

# --- Same block on every row of a small table ------------------------------
t = hl.utils.range_table(3)
t = t.annotate(block=x)
t.collect()
t = t.annotate(product=block_product(t.block, y))
t.product.collect()

# --- Aggregate the flattened products --------------------------------------
# Element-wise sum of the flat blocks only.
t.aggregate(hl.agg.array_sum(t.product.block))

# Sum and shape returned together, then reassembled locally into an ndarray.
thing = t.aggregate(hl.struct(
    the_sum=hl.agg.array_sum(t.product.block),
    the_shape=hl.agg.take(t.product.shape, 1)[0],
))
thing
hl.nd.from_column_major(thing.the_sum, thing.the_shape).collect()

# Same computation, with the ndarray rebuilt inside the aggregation itself.
thing = t.aggregate(hl.nd.from_column_major(
    hl.agg.array_sum(t.product.block),
    hl.agg.take(t.product.shape, 1)[0],
))
thing
def to_column_major(ndarray):
    """Flatten a two-dimensional ndarray into a column-major array.

    Args:
        ndarray: an ndarray whose ``shape`` unpacks into (rows, columns).

    Returns:
        An array of ``rows * columns`` entries in which the row index varies
        fastest.
    """
    row_count, col_count = ndarray.shape

    def entry_at(flat_index):
        # Invert the column-major layout: flat = col * row_count + row.
        return ndarray[flat_index % row_count, flat_index // row_count]

    flat_indices = hl.range(hl.int(row_count * col_count))
    return flat_indices.map(entry_at)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment