Skip to content

Instantly share code, notes, and snippets.

@malfet
Last active February 12, 2025 18:16
Show Gist options
  • Save malfet/8610f5eadd6916e3203a695352ae6bd9 to your computer and use it in GitHub Desktop.
Save malfet/8610f5eadd6916e3203a695352ae6bd9 to your computer and use it in GitHub Desktop.
// Metal Shading Language source that is compiled at runtime via `makeLibrary(source:options:)`.
//
// It declares:
//   * `add_functor` at global scope and `sub_functor` inside an anonymous namespace,
//     each with a templated call operator taking two values of type T;
//   * a templated kernel `binary_executor<T, F>` that applies functor `F` element-wise
//     to `input` and `other` and stores the result in `out`;
//   * two explicit instantiations of that kernel for `float`, exported under the
//     host names "add_float" and "sub_float".
//
// The only difference between the two instantiations, apart from the arithmetic operator,
// is the namespace in which the functor type is declared.
let shader_source = """
struct add_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(a + b);
}
};
namespace {
struct sub_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(a - b);
}
};
} // anonymous namespace
template <typename T, typename F>
kernel void binary_executor(
constant T* input [[buffer(0)]],
constant T* other [[buffer(1)]],
device T* out [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
F f;
out[tid] = f(input[tid], other[tid]);
}
template
[[host_name("add_float")]] kernel void binary_executor<float, add_functor>(constant float*, constant float *, device float*, uint);
template
[[host_name("sub_float")]] kernel void binary_executor<float, sub_functor>(constant float*, constant float *, device float*, uint);
"""
import Metal
// Use the first Metal device the system reports; abort if there is none.
guard let device = MTLCopyAllDevices().first else { fatalError("No Metal device found") }

// Compile the embedded shader source with default compile options.
// `try!` is intentional for this repro script: a compilation failure aborts the run.
let library = try! device.makeLibrary(source: shader_source, options: MTLCompileOptions())

// Expected: both exported kernels ("add_float" and "sub_float") are printed.
// Observed: only one kernel is printed, the one whose functor is declared in the
// global namespace; the instantiation using the anonymous-namespace functor is missing.
for kernel_name in library.functionNames {
    print(kernel_name)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment