Created
June 10, 2022 20:21
-
-
Save Lunderberg/7dcd4edbdd7bedfb08072037792aa585 to your computer and use it in GitHub Desktop.
Examining the hoisting of cached transforms from TIR into graph level
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Initial end-to-end model | |
@script.ir_module
class EndToEndModel:
    # Baseline module: two chained 1-d convolutions sharing the filter F,
    # with every buffer kept in its flat, untransformed layout.

    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        # Pass the declared parameter `x` (the previous text referenced an
        # undefined `X`).
        y = call_tir(conv1d_16, x, F)
        z = call_tir(conv1d_18, y, F)
        return z

    @T.prim_func
    def conv1d_16(
        X: T.Buffer[(16,), "float32"],
        F: T.Buffer[(3,), "float32"],
        Y: T.Buffer[(18,), "float32"],
    ):
        # Y[Yi] accumulates F[fi] * X[Yi - fi + 2] over the three filter
        # taps; taps whose input index falls outside [0, 16) are skipped.
        for Yi in T.serial(18):
            Y[Yi] = 0.0
            for fi in T.serial(3):
                Xi = Yi - fi + 2
                if 0 <= Xi < 16:
                    Y[Yi] = Y[Yi] + F[fi] * X[Xi]

    @T.prim_func
    def conv1d_18(
        Y: T.Buffer[(18,), "float32"],
        F: T.Buffer[(3,), "float32"],
        Z: T.Buffer[(20,), "float32"],
    ):
        # Same computation as conv1d_16, for an 18-element input and a
        # 20-element output.
        for Zi in T.serial(20):
            Z[Zi] = 0.0
            for fi in T.serial(3):
                Yi = Zi - fi + 2
                if 0 <= Yi < 18:
                    Z[Zi] = Z[Zi] + F[fi] * Y[Yi]
# After applying the same simplifications as proposed in the RFC, but | |
# with a cache_read/cache_write stage that contains the transformed | |
# buffers, rather than treating the input argument as transformed. | |
@script.ir_module
class EndToEndModel:
    # Each PrimFunc keeps its flat input/output signature but stages the
    # data through [3, 8] caches allocated inside the function.

    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        y = call_tir(conv1d_16, x, F)
        z = call_tir(conv1d_18, y, F)
        return z

    # X_read_cache = sched.cache_read(X)
    # Y_write_cache = sched.cache_write(Y)
    # sched.transform_layout(X_read_cache, index_map = lambda i: [(i+2)//8, (i+2)%8], pad_value = 0.0)
    # sched.transform_layout(Y_write_cache, index_map = lambda i: [(i+2)//8, (i+2)%8], pad_value = 0.0)
    @T.prim_func
    def conv1d_16(
        X: T.Buffer[(16,), "float32"],
        F: T.Buffer[(3,), "float32"],
        Y: T.Buffer[(18,), "float32"],
    ):
        # Stage 1: copy X into the [3, 8] read cache. Cache element
        # [io, ii] holds X[8*io + ii - 2]; elements that map outside X
        # are filled with 0.0.
        X_read_cache = T.alloc_buffer([3, 8], "float32")
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 16:
                X_read_cache[io, ii] = X[i]
            else:
                X_read_cache[io, ii] = 0.0

        # Stage 2: zero-initialise the write cache. A 2-d iteration needs
        # T.grid; T.serial yields a single loop variable and cannot be
        # unpacked into (io, ii).
        Y_write_cache = T.alloc_buffer([3, 8], "float32")
        for io, ii in T.grid(3, 8):
            Y_write_cache[io, ii] = 0.0

        # Stage 3: accumulate in the cached layout. The destination is
        # the element at flat offset (8*io + ii + fi) modulo 24.
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * X_read_cache[io, ii]
                    )

        # Stage 4: write the in-bounds part of the cache back to flat Y.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y[i] = Y_write_cache[io, ii]

    # Y_read_cache = sched.cache_read(Y)
    # Z_write_cache = sched.cache_write(Z)
    # sched.transform_layout(Y_read_cache, index_map = lambda i: [(i+2)//8, (i+2)%8], pad_value = 0.0)
    # sched.transform_layout(Z_write_cache, index_map = lambda i: [(i+2)//8, (i+2)%8], pad_value = 0.0)
    @T.prim_func
    def conv1d_18(
        Y: T.Buffer[(18,), "float32"],
        F: T.Buffer[(3,), "float32"],
        Z: T.Buffer[(20,), "float32"],
    ):
        # Stage 1: copy Y into the [3, 8] read cache, writing 0.0 to the
        # elements that map outside Y.
        Y_read_cache = T.alloc_buffer([3, 8], "float32")
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y_read_cache[io, ii] = Y[i]
            else:
                Y_read_cache[io, ii] = 0.0

        # Stage 2: zero-initialise the write cache (T.grid, not T.serial,
        # for the two-variable loop).
        Z_write_cache = T.alloc_buffer([3, 8], "float32")
        for io, ii in T.grid(3, 8):
            Z_write_cache[io, ii] = 0.0

        # Stage 3: accumulate in the cached layout; destination is flat
        # offset (8*io + ii + fi) modulo 24.
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * Y_read_cache[io, ii]
                    )

        # Stage 4: write the in-bounds part of the cache back to flat Z.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 20:
                Z[i] = Z_write_cache[io, ii]
# Hoist out layout transformations into independent functions | |
@script.ir_module
class EndToEndModel:
    # The cache-fill and write-back stages are now separate PrimFuncs
    # called from main(); the conv functions operate directly on the
    # (3, 8) buffers.

    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        # Pass the declared parameter `x` (previously an undefined `X`).
        X_read_cache = call_tir(transform_X, x)
        Y_write_cache = call_tir(conv1d_16, X_read_cache, F)
        # Consume the conv output `Y_write_cache` (previously an
        # undefined `Y_cache`).
        Y = call_tir(inv_transform_Y, Y_write_cache)
        Y_read_cache = call_tir(transform_Y, Y)
        Z_write_cache = call_tir(conv1d_18, Y_read_cache, F)
        Z = call_tir(inv_transform_Z, Z_write_cache)
        return Z

    @T.prim_func
    def transform_X(
        X: T.Buffer[(16,), "float32"],
        X_read_cache: T.Buffer[(3, 8), "float32"],
    ):
        # Cache element [io, ii] receives X[8*io + ii - 2]; elements that
        # map outside X are filled with 0.0.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 16:
                X_read_cache[io, ii] = X[i]
            else:
                X_read_cache[io, ii] = 0.0

    @T.prim_func
    def conv1d_16(
        # Declared (3, 8) to match the 2-d indexing below and the output
        # of transform_X (previously declared as (16,)).
        X_read_cache: T.Buffer[(3, 8), "float32"],
        F: T.Buffer[(3,), "float32"],
        Y_write_cache: T.Buffer[(3, 8), "float32"],
    ):
        # Zero-initialise the output; T.grid supplies the two loop
        # variables (T.serial yields only one).
        for io, ii in T.grid(3, 8):
            Y_write_cache[io, ii] = 0.0

        # Accumulate in the cached layout; destination is flat offset
        # (8*io + ii + fi) modulo 24.
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * X_read_cache[io, ii]
                    )

    @T.prim_func
    def inv_transform_Y(
        Y_write_cache: T.Buffer[(3, 8), "float32"],
        Y: T.Buffer[(18,), "float32"],
    ):
        # Copy the in-bounds cache elements back to flat Y; elements that
        # map outside [0, 18) are dropped.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y[i] = Y_write_cache[io, ii]

    @T.prim_func
    def transform_Y(
        Y: T.Buffer[(18,), "float32"],
        Y_read_cache: T.Buffer[(3, 8), "float32"],
    ):
        # Cache element [io, ii] receives Y[8*io + ii - 2]; elements that
        # map outside Y are filled with 0.0.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y_read_cache[io, ii] = Y[i]
            else:
                Y_read_cache[io, ii] = 0.0

    @T.prim_func
    def conv1d_18(
        Y_read_cache: T.Buffer[(3, 8), "float32"],
        F: T.Buffer[(3,), "float32"],
        Z_write_cache: T.Buffer[(3, 8), "float32"],
    ):
        # Zero-initialise the output (T.grid for the two-variable loop).
        for io, ii in T.grid(3, 8):
            Z_write_cache[io, ii] = 0.0

        # Accumulate in the cached layout; destination is flat offset
        # (8*io + ii + fi) modulo 24.
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * Y_read_cache[io, ii]
                    )

    @T.prim_func
    def inv_transform_Z(
        Z_write_cache: T.Buffer[(3, 8), "float32"],
        Z: T.Buffer[(20,), "float32"],
    ):
        # Copy the in-bounds cache elements back to flat Z.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 20:
                Z[i] = Z_write_cache[io, ii]
# Merging the calls to inv_transform_Y and transform_Y | |
@script.ir_module
class EndToEndModel:
    # inv_transform_Y and transform_Y are replaced by one fused PrimFunc,
    # and conv1d_16 additionally writes 0.0 into the padding elements of
    # its output.

    @R.func
    def main(x: R.Tensor[16]):
        F = R.const(shape=[3])
        # Pass the declared parameter `x` (previously an undefined `X`).
        X_read_cache = call_tir(transform_X, x)
        Y_write_cache = call_tir(conv1d_16, X_read_cache, F)
        # The fused function produces Y_read_cache directly. The stale
        # `call_tir(transform_Y, Y)` line was dropped: neither
        # transform_Y nor Y exists in this module any more.
        Y_read_cache = call_tir(fused_inv_transform_Y_transform_Y, Y_write_cache)
        Z_write_cache = call_tir(conv1d_18, Y_read_cache, F)
        Z = call_tir(inv_transform_Z, Z_write_cache)
        return Z

    @T.prim_func
    def transform_X(
        X: T.Buffer[(16,), "float32"],
        X_read_cache: T.Buffer[(3, 8), "float32"],
    ):
        # Cache element [io, ii] receives X[8*io + ii - 2]; elements that
        # map outside X are filled with 0.0.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 16:
                X_read_cache[io, ii] = X[i]
            else:
                X_read_cache[io, ii] = 0.0

    @T.prim_func
    def conv1d_16(
        # Declared (3, 8) to match the 2-d indexing below and the output
        # of transform_X (previously declared as (16,)).
        X_read_cache: T.Buffer[(3, 8), "float32"],
        F: T.Buffer[(3,), "float32"],
        Y_write_cache: T.Buffer[(3, 8), "float32"],
    ):
        # Zero-initialise the output; T.grid supplies the two loop
        # variables (T.serial yields only one).
        for io, ii in T.grid(3, 8):
            Y_write_cache[io, ii] = 0.0

        # Accumulate in the cached layout; destination is flat offset
        # (8*io + ii + fi) modulo 24.
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Y_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * X_read_cache[io, ii]
                    )

        # Overwrite every element whose flat index 8*io + ii - 2 lies
        # outside [0, 18) with 0.0, so the padding has a known value.
        for io in T.serial(3):
            for ii in T.serial(8):
                i = 8 * io + ii - 2
                if not (0 <= i < 18):
                    Y_write_cache[io, ii] = 0.0

    @T.prim_func
    def fused_inv_transform_Y_transform_Y(
        Y_write_cache: T.Buffer[(3, 8), "float32"],
        Y_read_cache: T.Buffer[(3, 8), "float32"],
    ):
        # Flat intermediate, allocated with the same call form used for
        # the other caches in this file (the previous text subscripted
        # T.alloc_buffer and wrapped the result in a tuple).
        Y = T.alloc_buffer([18], "float32")

        # Former inv_transform_Y: cache -> flat Y.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y[i] = Y_write_cache[io, ii]

        # Former transform_Y: flat Y -> cache, with 0.0 in the padding.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 18:
                Y_read_cache[io, ii] = Y[i]
            else:
                Y_read_cache[io, ii] = 0.0

    @T.prim_func
    def conv1d_18(
        Y_read_cache: T.Buffer[(3, 8), "float32"],
        F: T.Buffer[(3,), "float32"],
        Z_write_cache: T.Buffer[(3, 8), "float32"],
    ):
        # Zero-initialise the output (T.grid for the two-variable loop).
        for io, ii in T.grid(3, 8):
            Z_write_cache[io, ii] = 0.0

        # Accumulate in the cached layout; destination is flat offset
        # (8*io + ii + fi) modulo 24.
        for io in T.serial(3):
            for ii in T.serial(8):
                for fi in T.serial(3):
                    Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8] = (
                        Z_write_cache[(io + (ii + fi) // 8) % 3, (ii + fi) % 8]
                        + F[fi] * Y_read_cache[io, ii]
                    )

    @T.prim_func
    def inv_transform_Z(
        Z_write_cache: T.Buffer[(3, 8), "float32"],
        Z: T.Buffer[(20,), "float32"],
    ):
        # Copy the in-bounds cache elements back to flat Z.
        for io, ii in T.grid(3, 8):
            i = 8 * io + ii - 2
            if 0 <= i < 20:
                Z[i] = Z_write_cache[io, ii]
# Same as previous, but only considering this function. If we can
# prove that this is equivalent to a memcpy, then we are justified in
# removing it from main(), and replacing all uses of `Y_read_cache`
# with `Y_write_cache`.
@T.prim_func
def fused_inv_transform_Y_transform_Y(
    Y_write_cache: T.Buffer[(3, 8), "float32"],
    Y_read_cache: T.Buffer[(3, 8), "float32"],
):
    # Flat 18-element intermediate. Allocated with the call form
    # T.alloc_buffer(shape, dtype); the previous text subscripted
    # T.alloc_buffer and wrapped the result in a one-element tuple, so
    # `Y[i]` below would not have indexed a buffer.
    Y = T.alloc_buffer([18], "float32")

    # First half: copy each cache element whose flat index
    # i = 8*io + ii - 2 lies in [0, 18) into Y[i].
    for io, ii in T.grid(3, 8):
        i = 8 * io + ii - 2
        if 0 <= i < 18:
            Y[i] = Y_write_cache[io, ii]

    # Second half: rebuild the (3, 8) layout from Y, writing 0.0 to the
    # elements whose flat index lies outside [0, 18).
    for io, ii in T.grid(3, 8):
        i = 8 * io + ii - 2
        if 0 <= i < 18:
            Y_read_cache[io, ii] = Y[i]
        else:
            Y_read_cache[io, ii] = 0.0
# After inlining Y[i]. In order to prove that this is equivalent to a | |
# memcpy, we would need to know a priori that when we are outside of | |
# the bounds of `0 <= i < 18`, `Y_write_cache[io, ii]` contains the | |
# value 0.0. This is not something that could be determined from any | |
# local analysis of this function, and would require reconstructing | |
# the buffer constraint based on analysis of other functions. | |
@T.prim_func
def fused_inv_transform_Y_transform_Y(
    Y_write_cache: T.Buffer[(3, 8), "float32"],
    Y_read_cache: T.Buffer[(3, 8), "float32"],
):
    # Element-wise pass over the (3, 8) buffers: padding elements (flat
    # index outside [0, 18)) are written as 0.0, all other elements are
    # copied from Y_write_cache to the same position in Y_read_cache.
    for io in T.serial(3):
        for ii in T.serial(8):
            flat_index = 8 * io + ii - 2
            if not (0 <= flat_index < 18):
                Y_read_cache[io, ii] = 0.0
            else:
                Y_read_cache[io, ii] = Y_write_cache[io, ii]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment