// Source: GitHub Gist by @philipturner, created June 29, 2023 21:24
// https://gist.github.com/philipturner/83241987a3abf291bd3fbd1aa090b99c
//
// main.swift
// KMeans
//
// Created by Philip Turner on 6/29/23.
//
import Metal
example_smawk()

/// Smoke test for `metal_smawk`: runs the GPU kernel once with n = 4096 on
/// zero-filled host buffers of 16384 elements each, printing a message
/// before and after the call.
func example_smawk() {
  let rowCount = 4096
  let capacity = 16384

  var columnScratch = [Int32](repeating: 0, count: capacity)
  var costs = [Float](repeating: 0, count: capacity)
  var prefixSums = [Float](repeating: 0, count: capacity)
  var prefixSquareSums = [Float](repeating: 0, count: capacity)
  var argmins = [Int32](repeating: 0, count: capacity)

  print("'smawk' started.")
  // Each array must stay pinned while the GPU copies are made, so the calls
  // nest one `withUnsafeMutableBufferPointer` per buffer.
  columnScratch.withUnsafeMutableBufferPointer { colsPtr in
    costs.withUnsafeMutableBufferPointer { dPtr in
      prefixSums.withUnsafeMutableBufferPointer { sumPtr in
        prefixSquareSums.withUnsafeMutableBufferPointer { sum2Ptr in
          argmins.withUnsafeMutableBufferPointer { resultPtr in
            metal_smawk(
              n: Int32(rowCount),
              cols: colsPtr,
              D: dPtr,
              cumsum: sumPtr,
              cumsum2: sum2Ptr,
              result: resultPtr)
          }
        }
      }
    }
  }
  print("'smawk' finished.")
}
/// Runs the `smawk` Metal kernel (source from `makeShaderSource()`) on a
/// single GPU thread and copies every buffer back to the caller.
///
/// Each host buffer is copied into its own `MTLBuffer`. The kernel is
/// dispatched once, and afterwards each GPU buffer is copied back into the
/// *same* host buffer it came from. The shader writes `cols` (as scratch
/// space) and `result`. `D`, `cumsum` and `cumsum2` are only read by the
/// shader, so they round-trip unchanged. Don't pass anything in `cols` or
/// `result` that you expect to stay constant.
///
/// - Parameters:
///   - n: Number of rows (and initial columns) the kernel processes. Must be
///     non-negative.
///   - cols: Scratch space. The shader uses `cols[0..<n]` as a reverse-lookup
///     table and the entries after index `n` for reduced column lists
///     (presumably needs at least `2 * n` elements — TODO confirm).
///   - D: Read by the shader's cost lookup at column indices below `n - 1`.
///   - cumsum: Read by the shader at indices up to `n`.
///   - cumsum2: Read by the shader at indices up to `n`.
///   - result: Receives, for each row in `0..<n`, the column index the
///     kernel selected for that row.
func metal_smawk(
  n: Int32,
  cols: UnsafeMutableBufferPointer<Int32>,
  D: UnsafeMutableBufferPointer<Float>,
  cumsum: UnsafeMutableBufferPointer<Float>,
  cumsum2: UnsafeMutableBufferPointer<Float>,
  result: UnsafeMutableBufferPointer<Int32>
) {
  precondition(n >= 0, "n must be non-negative")
  let device = MTLCopyAllDevices().first!
  let commandQueue = device.makeCommandQueue()!

  // Copies a host buffer into a new MTLBuffer. At least one element is
  // always allocated so an empty host buffer never requests a zero-length
  // Metal buffer.
  func upload<T>(_ host: UnsafeMutableBufferPointer<T>) -> MTLBuffer {
    let byteCount = host.count * MemoryLayout<T>.stride
    let buffer = device.makeBuffer(length: max(byteCount, MemoryLayout<T>.stride))!
    if let base = host.baseAddress {
      buffer.contents().copyMemory(from: base, byteCount: byteCount)
    }
    return buffer
  }

  // Copies a GPU buffer's contents back into the host buffer it was made from.
  func download<T>(_ buffer: MTLBuffer, into host: UnsafeMutableBufferPointer<T>) {
    guard let base = host.baseAddress else { return }
    UnsafeMutableRawPointer(base).copyMemory(
      from: buffer.contents(), byteCount: host.count * MemoryLayout<T>.stride)
  }

  let bufferCols = upload(cols)
  let bufferD = upload(D)
  let bufferCumsum = upload(cumsum)
  let bufferCumsum2 = upload(cumsum2)
  let bufferResult = upload(result)

  let library = try! device.makeLibrary(source: makeShaderSource(), options: nil)
  // let library = device.makeDefaultLibrary()!
  let function = library.makeFunction(name: "smawk")!
  let desc = MTLComputePipelineDescriptor()
  desc.computeFunction = function
  // The shader recurses _smawk0 -> _smawk1 -> _smawk2 -> _smawk2 -> ...,
  // halving row_size at each level until it reaches 0. The depth therefore
  // grows with the bit length of n, not its trailing-zero count, which
  // matters for non-power-of-two n. For n = 2^k this evaluates to k + 4.
  desc.maxCallStackDepth = 3 + (Int32.bitWidth - n.leadingZeroBitCount)
  let pipeline = try! device.makeComputePipelineState(descriptor: desc, options: [], reflection: nil)

  let commandBuffer = commandQueue.makeCommandBuffer()!
  let encoder = commandBuffer.makeComputeCommandEncoder()!
  encoder.setComputePipelineState(pipeline)
  var nValue = n
  encoder.setBytes(&nValue, length: MemoryLayout<Int32>.stride, index: 0)
  encoder.setBuffer(bufferCols, offset: 0, index: 1)
  encoder.setBuffer(bufferD, offset: 0, index: 2)
  encoder.setBuffer(bufferCumsum, offset: 0, index: 3)
  encoder.setBuffer(bufferCumsum2, offset: 0, index: 4)
  encoder.setBuffer(bufferResult, offset: 0, index: 5)
  // The kernel is written to run on exactly one thread.
  encoder.dispatchThreads(MTLSizeMake(1, 1, 1), threadsPerThreadgroup: MTLSizeMake(1, 1, 1))
  encoder.endEncoding()
  commandBuffer.commit()
  commandBuffer.waitUntilCompleted()
  if let error = commandBuffer.error {
    fatalError("'smawk' command buffer failed: \(error)")
  }

  // Each GPU buffer goes back into its own host buffer.
  download(bufferCols, into: cols)
  download(bufferD, into: D)
  download(bufferCumsum, into: cumsum)
  download(bufferCumsum2, into: cumsum2)
  download(bufferResult, into: result)
}
/// Returns the Metal Shading Language source for the `smawk` compute kernel.
/// `metal_smawk` compiles it at runtime with `device.makeLibrary(source:options:)`.
///
/// Kernel bindings: `n` at buffer(0), `cols` at 1, `D` at 2, `cumsum` at 3,
/// `cumsum2` at 4, `result` at 5.
///
/// `smawk` calls `_smawk0` on rows `0..<n` with identity columns. It passes
/// `cols` as the reverse-lookup table (`reserved`) and `cols + n` as scratch
/// for reduced column lists. `_smawk0` recurses into `_smawk1`, which
/// recurses into `_smawk2`. Each level recurses on every other row
/// (`row_start + row_stride`, `row_stride * 2`, `row_size / 2`).
///
/// The string is returned verbatim; any edit to it changes the compiled kernel.
///
/// NOTE(review): `_smawk0` contains a "Build the reverse lookup table."
/// comment but builds no table. Its columns are the identity, so it uses the
/// next row's result directly as `stop`.
func makeShaderSource() -> String {
"""
//
// Kernels.metal
// KMeans
//
// Created by Philip Turner on 6/29/23.
//
#include <metal_stdlib>
using namespace metal;
inline static float _kmeans1d_cost(device float* cumsum, device float* cumsum2, int i, int j)
{
if (j < i)
return 0;
float mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1);
float result = cumsum2[j + 1] - cumsum2[i];
result += (j - i + 1) * (mu * mu);
result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]);
return result;
}
inline static float _kmeans1d_lookup(device float* D, device float* cumsum, device float* cumsum2, int i, int j)
{
const int col = i < j - 1 ? i : j - 1;
return (col >= 0 ? D[col] : 0) + _kmeans1d_cost(cumsum, cumsum2, j, i);
}
static void _smawk2(int row_start, int row_stride, int row_size, device int* cols, int col_size, device int* reserved, device float* D, device float* cumsum, device float* cumsum2, device int* result)
{
if (row_size == 0)
return;
device int* _cols = cols + col_size;
int _col_size = 0;
int i;
for (i = 0; i < col_size; i++)
{
const int col = cols[i];
for (;;)
{
if (_col_size == 0)
break;
const int row = row_start + row_stride * (_col_size - 1);
if (_kmeans1d_lookup(D, cumsum, cumsum2, row, col) >= _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[_col_size - 1]))
break;
--_col_size;
}
if (_col_size < row_size)
{
_cols[_col_size] = col;
++_col_size;
}
}
_smawk2(row_start + row_stride, row_stride * 2, row_size / 2, _cols, _col_size, reserved, D, cumsum, cumsum2, result);
// Build the reverse lookup table.
for (i = 0; i < _col_size; i++)
reserved[_cols[i]] = i;
int start = 0;
for (i = 0; i < row_size; i += 2) {
const int row = row_start + i * row_stride;
int stop = _col_size - 1;
if (i < row_size - 1)
{
const int argmin = result[row_start + (i + 1) * row_stride];
stop = reserved[argmin];
}
int argmin = _cols[start];
float min = _kmeans1d_lookup(D, cumsum, cumsum2, row, argmin);
int c;
for (c = start + 1; c <= stop; c++)
{
float value = _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[c]);
if (value < min) {
argmin = _cols[c];
min = value;
}
}
result[row] = argmin;
start = stop;
}
}
static void _smawk1(int row_start, int row_stride, int row_size, device int* cols, int col_size, device int* reserved, device float* D, device float* cumsum, device float* cumsum2, device int* result)
{
if (row_size == 0)
return;
device int* _cols = cols;
int _col_size = 0;
int i;
for (i = 0; i < col_size; i++)
{
const int col = i;
for (;;)
{
if (_col_size == 0)
break;
const int row = row_start + row_stride * (_col_size - 1);
if (_kmeans1d_lookup(D, cumsum, cumsum2, row, col) >= _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[_col_size - 1]))
break;
--_col_size;
}
if (_col_size < row_size)
{
_cols[_col_size] = col;
++_col_size;
}
}
_smawk2(row_start + row_stride, row_stride * 2, row_size / 2, _cols, _col_size, reserved, D, cumsum, cumsum2, result);
// Build the reverse lookup table.
for (i = 0; i < _col_size; i++)
reserved[_cols[i]] = i;
int start = 0;
for (i = 0; i < row_size; i += 2) {
const int row = row_start + i * row_stride;
int stop = _col_size - 1;
if (i < row_size - 1)
{
const int argmin = result[row_start + (i + 1) * row_stride];
stop = reserved[argmin];
}
int argmin = _cols[start];
float min = _kmeans1d_lookup(D, cumsum, cumsum2, row, argmin);
int c;
for (c = start + 1; c <= stop; c++)
{
float value = _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[c]);
if (value < min) {
argmin = _cols[c];
min = value;
}
}
result[row] = argmin;
start = stop;
}
}
static void _smawk0(int row_start, int row_stride, int row_size, device int* cols, int col_size, device int* reserved, device float* D, device float* cumsum, device float* cumsum2, device int* result)
{
if (row_size == 0)
return;
_smawk1(row_start + row_stride, row_stride * 2, row_size / 2, cols, col_size, reserved, D, cumsum, cumsum2, result);
// Build the reverse lookup table.
int start = 0;
int i;
for (i = 0; i < row_size; i += 2) {
const int row = row_start + i * row_stride;
int stop = col_size - 1;
if (i < row_size - 1)
{
const int argmin = result[row_start + (i + 1) * row_stride];
stop = argmin;
}
int argmin = start;
float min = _kmeans1d_lookup(D, cumsum, cumsum2, row, argmin);
int c;
for (c = start + 1; c <= stop; c++)
{
float value = _kmeans1d_lookup(D, cumsum, cumsum2, row, c);
if (value < min) {
argmin = c;
min = value;
}
}
result[row] = argmin;
start = stop;
}
}
// Only works when you dispatch a single GPU thread. Once we verify this works,
// the next step is supporting multiple threads in one invocation, running
// different blocks in parallel.
//
// Another optimization is working directly on half-precision numbers, which
// would let more data fit inside the GPU's tiny caches.
kernel void smawk(constant int &n [[buffer(0)]],
device int* cols [[buffer(1)]],
device float* D [[buffer(2)]],
device float* cumsum [[buffer(3)]],
device float* cumsum2 [[buffer(4)]],
device int* result [[buffer(5)]])
{
_smawk0(0, 1, n, cols + n, n, cols, D, cumsum, cumsum2, result);
}
"""
}
// End of gist content.