# How to reuse shared memory
# Right now the MPS inductor produces the following code:
# #include <c10/metal/random.h>
# #include <c10/metal/special_math.h>
# #include <c10/metal/utils.h>
# #include <c10/metal/reduction_utils.h>
# kernel void generated_kernel(
#     device float* out_ptr0,
#     device float* out_ptr1,
#     constant float* in_ptr0,
#     uint2 thread_pos [[thread_position_in_grid]],
#     uint2 group_pos [[thread_position_in_threadgroup]]
# ) {
#     auto xindex = thread_pos.x;
#     auto r0_index = thread_pos.y;
#     int r0_0 = r0_index;
#     threadgroup float tmp_acc_0[128];
#     threadgroup float tmp_acc_1[128];
#     auto tmp0 = in_ptr0[r0_0];
#     tmp_acc_0[r0_index] = static_cast<float>(tmp0);
#     tmp_acc_1[r0_index] = static_cast<float>(tmp0);
#     auto tmp1 = c10::metal::threadgroup_min(tmp_acc_0, 128);
#     out_ptr0[0] = static_cast<float>(tmp1);
#     auto tmp2 = c10::metal::threadgroup_max(tmp_acc_1, 128);
#     out_ptr1[0] = static_cast<float>(tmp2);
# }
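#
# One way the two threadgroup arrays could be collapsed into one (a sketch, not
# what inductor emits today): run the min reduction first, synchronize the
# threadgroup, then refill the same buffer for the max.
# threadgroup_barrier(mem_flags::mem_threadgroup) is standard Metal; the
# c10::metal reduction helpers are assumed to behave as in the code above, and
# the refill before the second reduction assumes they may clobber the buffer.
#
# kernel void generated_kernel_reused(
#     device float* out_ptr0,
#     device float* out_ptr1,
#     constant float* in_ptr0,
#     uint2 thread_pos [[thread_position_in_grid]],
#     uint2 group_pos [[thread_position_in_threadgroup]]
# ) {
#     auto r0_index = thread_pos.y;
#     int r0_0 = r0_index;
#     threadgroup float tmp_acc[128];  // single shared buffer, reused for both reductions
#     auto tmp0 = in_ptr0[r0_0];
#     tmp_acc[r0_index] = static_cast<float>(tmp0);
#     auto tmp1 = c10::metal::threadgroup_min(tmp_acc, 128);
#     out_ptr0[0] = static_cast<float>(tmp1);
#     // wait until every thread is done with the min before the buffer is overwritten
#     threadgroup_barrier(mem_flags::mem_threadgroup);
#     tmp_acc[r0_index] = static_cast<float>(tmp0);
#     auto tmp2 = c10::metal::threadgroup_max(tmp_acc, 128);
#     out_ptr1[0] = static_cast<float>(tmp2);
# }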
import torch
from torch._inductor.utils import run_and_get_code

if __name__ == "__main__":
    # Compile a min/max reduction on MPS and print the generated Metal kernel
    result, code = run_and_get_code(
        torch.compile(lambda x: (x.min(), x.max())), torch.rand(128, device='mps')
    )
    print(code[0])