Created
November 10, 2022 04:56
-
-
Save ailzhang/d207a53c26720cccb7e7c5a9687009c2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import taichi as ti

ti.init(ti.vulkan)

# Vector element types shared by the ndarrays and kernel annotations below.
tp_ivec3 = ti.types.vector(3, ti.i32)
tp_ivec2 = ti.types.vector(2, ti.i32)
tp_vec3 = ti.types.vector(3, ti.f32)

# Test ndarrays: scalar and vector elements, i32 and f32 dtypes, 1-D and 2-D shapes.
x = ti.ndarray(ti.f32, shape=(12, 13))
y = ti.ndarray(tp_ivec3, shape=(12, 4))
z = ti.ndarray(ti.i32, shape=(12, 12))
m = ti.ndarray(tp_ivec2, shape=(2, 3))
n = ti.ndarray(ti.f32, shape=(12,))
k = ti.ndarray(tp_vec3, shape=(12,))
@ti.kernel
def test1(arr: ti.types.ndarray()):
    """Write 1 to every element of ``arr``.

    The annotation passes no dtype or dimension arguments.
    """
    for I in ti.grouped(arr):
        arr[I] = 1
test1(x)
print(x.to_numpy())

# The following currently errors out (y has tp_ivec3 vector elements, while
# the kernel assigns a scalar). Shall we automatically expand the scalar?
# test1(y)
# print(y.to_numpy())
@ti.kernel
def test2(arr: ti.types.ndarray(dtype=ti.f32)):
    """Write 2.0 to every element of ``arr``, which is annotated with dtype=ti.f32."""
    for I in ti.grouped(arr):
        arr[I] = 2.0
test2(x)
print(x.to_numpy())

# This works, but should it error out? z is an i32 ndarray, which does not
# match the dtype=ti.f32 type annotation.
test2(z)
print(z.to_numpy())

# This should definitely error out (y has tp_ivec3 elements), but the current
# error message is bad.
# test2(y)
@ti.kernel
def test3(arr: ti.types.ndarray(dtype=tp_ivec3)):
    """Assign [0, 1, 2] to every element of ``arr``, which is annotated with dtype=tp_ivec3."""
    for I in ti.grouped(arr):
        arr[I] = [0, 1, 2]
test3(y)
print(y.to_numpy())

# The following errors out as expected (m has tp_ivec2 elements), but the
# message can be improved.
# test3(m)

# This should have errored out but didn't: k has tp_vec3 (f32) elements, not
# the annotated tp_ivec3 (i32).
# test3(k)
@ti.kernel
def test4(arr: ti.types.ndarray(dtype=tp_ivec3)):
    """Deliberately bad kernel: assigns a 2-element list to tp_ivec3 elements.

    The mismatch is intentional; this kernel exists to exercise the error
    reported at kernel compilation.
    """
    for I in ti.grouped(arr):
        arr[I] = [1, 2]
# test4 is just a bad kernel. It errors out during kernel compilation as expected; maybe the error message could be better?
# test4(y)
@ti.kernel
def test5(arr: ti.types.ndarray(field_dim=1)):
    """Deliberately bad kernel: annotated with field_dim=1 but indexed with two indices.

    The mismatch is intentional; this kernel exists to exercise the resulting
    error message.
    """
    for i, j in arr:
        arr[i, j] = 0
# test5 is just a bad kernel. This should error out, but the error message is weird.
# test5(x)
@ti.kernel
def test6(arr: ti.types.ndarray(field_dim=2)):
    """Write 6 to every element of ``arr``, which is annotated with field_dim=2."""
    for i, j in arr:
        arr[i, j] = 6
# Works as expected.
test6(x)
print(x.to_numpy())

# Errors out, but the message could be more helpful.
# test6(m)

# Errors out, but the message is weird.
# test6(n)

# Similar behavior should happen in AOT compilation as well.
# NOTE: this rebinds `m`, which previously named the tp_ivec2 ndarray.
m = ti.aot.Module(ti.vulkan)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment