Created
March 12, 2025 16:39
-
-
Save malfet/0e5e669df6c6764959e4af48da247fb7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# How to reuse shared memory
# Right now MPS inductor produces the following code:
# #include <c10/metal/random.h>
# #include <c10/metal/special_math.h>
# #include <c10/metal/utils.h>
# #include <c10/metal/reduction_utils.h>
# kernel void generated_kernel(
#     device float* out_ptr0,
#     device float* out_ptr1,
#     constant float* in_ptr0,
#     uint2 thread_pos [[thread_position_in_grid]],
#     uint2 group_pos [[thread_position_in_threadgroup]]
# ) {
#     auto xindex = thread_pos.x;
#     auto r0_index = thread_pos.y;
#     int r0_0 = r0_index;
#     threadgroup float tmp_acc_0[128];
#     threadgroup float tmp_acc_1[128];
#     auto tmp0 = in_ptr0[r0_0];
#     tmp_acc_0[r0_index] = static_cast<float>(tmp0);
#     tmp_acc_1[r0_index] = static_cast<float>(tmp0);
#     auto tmp1 = c10::metal::threadgroup_min(tmp_acc_0, 128);
#     out_ptr0[0] = static_cast<float>(tmp1);
#     auto tmp2 = c10::metal::threadgroup_max(tmp_acc_1, 128);
#     out_ptr1[0] = static_cast<float>(tmp2);
# }
import torch
from torch._inductor.utils import run_and_get_code

if __name__ == "__main__":
    # Compile a combined min/max reduction over a 128-element MPS tensor and
    # print the first source string that Inductor generated for it. The
    # computed (min, max) values themselves are not needed here.
    _, code = run_and_get_code(
        torch.compile(lambda x: (x.min(), x.max())),
        torch.rand(128, device="mps"),
    )
    print(code[0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment