Created June 29, 2023 21:24
-
-
Save philipturner/83241987a3abf291bd3fbd1aa090b99c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // | |
| // main.swift | |
| // KMeans | |
| // | |
| // Created by Philip Turner on 6/29/23. | |
| // | |
| import Metal | |
example_smawk()

/// Smoke test: runs `metal_smawk` once on an all-zero problem of size `n`,
/// printing a line before and after the GPU work.
///
/// - Parameter n: Number of rows/columns to process. Must be positive; the
///   default of 4096 is the size this demo has always used.
func example_smawk(n: Int = 4096) {
  precondition(n > 0, "'smawk' example requires a positive n.")
  // The kernel needs 2 * n elements of `cols` scratch and n + 1 elements of
  // each prefix-sum array. Scaling every buffer to 4 * n covers that for any
  // n, and reproduces the original fixed size (16384) for the default n.
  let bufferSize = 4 * n
  var cols = [Int32](repeating: 0, count: bufferSize)
  var D = [Float](repeating: 0, count: bufferSize)
  var cumsum = [Float](repeating: 0, count: bufferSize)
  var cumsum2 = [Float](repeating: 0, count: bufferSize)
  var result = [Int32](repeating: 0, count: bufferSize)

  print("'smawk' started.")
  // Nested closures keep every array's storage pinned while metal_smawk reads
  // and overwrites it through buffer pointers.
  cols.withUnsafeMutableBufferPointer { cols in
    D.withUnsafeMutableBufferPointer { D in
      cumsum.withUnsafeMutableBufferPointer { cumsum in
        cumsum2.withUnsafeMutableBufferPointer { cumsum2 in
          result.withUnsafeMutableBufferPointer { result in
            metal_smawk(
              n: Int32(n), cols: cols, D: D, cumsum: cumsum,
              cumsum2: cumsum2, result: result)
          }
        }
      }
    }
  }
  print("'smawk' finished.")
}
/// Runs the single-threaded `smawk` Metal kernel built from `makeShaderSource()`
/// and copies every buffer back from the GPU.
///
/// All five buffers are uploaded to the GPU and then overwritten with whatever
/// the kernel left there, so don't pass anything you expect to stay constant.
///
/// - Parameters:
///   - n: Number of rows (and candidate columns) the kernel searches. Values
///     `<= 0` are a no-op, as they are inside the kernel.
///   - cols: Kernel scratch space. `cols[0..<n]` holds a column-to-position
///     lookup table and `cols[n...]` holds the stacks of surviving columns, so
///     at least `2 * n` elements are required.
///   - D: Read by the kernel's lookup function at indices below `n`.
///   - cumsum: Read at indices `0...n`; the kernel's cost function treats
///     `cumsum[j + 1] - cumsum[i]` as the sum of elements `i...j`
///     (presumably prefix sums of sorted input — confirm with callers).
///   - cumsum2: Read at indices `0...n`, indexed the same way as `cumsum`.
///   - result: Receives, for each row in `0..<n`, the column index the kernel
///     selected as that row's minimum.
func metal_smawk(
  n: Int32,
  cols: UnsafeMutableBufferPointer<Int32>,
  D: UnsafeMutableBufferPointer<Float>,
  cumsum: UnsafeMutableBufferPointer<Float>,
  cumsum2: UnsafeMutableBufferPointer<Float>,
  result: UnsafeMutableBufferPointer<Int32>
) {
  // With no rows the kernel returns immediately, so skip the GPU round trip
  // (and empty buffers, whose `baseAddress` may be nil).
  guard n > 0 else { return }

  // Reject undersized buffers on the host instead of letting the kernel read
  // or write out of bounds on the GPU.
  let rowCount = Int(n)
  precondition(cols.count >= 2 * rowCount, "'cols' must hold at least 2 * n elements.")
  precondition(D.count >= rowCount, "'D' must hold at least n elements.")
  precondition(cumsum.count > rowCount, "'cumsum' must hold at least n + 1 elements.")
  precondition(cumsum2.count > rowCount, "'cumsum2' must hold at least n + 1 elements.")
  precondition(result.count >= rowCount, "'result' must hold at least n elements.")

  // NOTE(review): MTLCopyAllDevices is kept instead of
  // MTLCreateSystemDefaultDevice, which is known to return nil in some macOS
  // command-line tools.
  guard let device = MTLCopyAllDevices().first else {
    fatalError("No Metal device is available.")
  }
  guard let commandQueue = device.makeCommandQueue() else {
    fatalError("Could not create a Metal command queue.")
  }

  // Copies a host array into a newly allocated GPU buffer of the same byte size.
  func upload<Element>(_ host: UnsafeMutableBufferPointer<Element>) -> MTLBuffer {
    let byteCount = host.count * MemoryLayout<Element>.stride
    guard let buffer = device.makeBuffer(bytes: host.baseAddress!, length: byteCount, options: []) else {
      fatalError("Could not allocate a \(byteCount)-byte Metal buffer.")
    }
    return buffer
  }

  // Copies a GPU buffer back over the host array it was uploaded from.
  func readBack<Element>(_ buffer: MTLBuffer, into host: UnsafeMutableBufferPointer<Element>) {
    let byteCount = host.count * MemoryLayout<Element>.stride
    UnsafeMutableRawPointer(host.baseAddress!).copyMemory(from: buffer.contents(), byteCount: byteCount)
  }

  let buffer_cols = upload(cols)
  let buffer_D = upload(D)
  let buffer_cumsum = upload(cumsum)
  let buffer_cumsum2 = upload(cumsum2)
  let buffer_result = upload(result)

  let pipeline: MTLComputePipelineState
  do {
    let library = try device.makeLibrary(source: makeShaderSource(), options: nil)
    // let library = device.makeDefaultLibrary()!
    guard let function = library.makeFunction(name: "smawk") else {
      fatalError("The shader source has no 'smawk' kernel.")
    }
    let desc = MTLComputePipelineDescriptor()
    desc.computeFunction = function
    // Each recursion level halves the row count, so the call depth grows with
    // the bit width of n. (trailingZeroBitCount undercounts it: it is 0 for
    // every odd n.)
    desc.maxCallStackDepth = 4 + (Int32.bitWidth - n.leadingZeroBitCount)
    pipeline = try device.makeComputePipelineState(descriptor: desc, options: [], reflection: nil)
  } catch {
    fatalError("Could not build the 'smawk' pipeline: \(error)")
  }

  guard let commandBuffer = commandQueue.makeCommandBuffer(),
        let encoder = commandBuffer.makeComputeCommandEncoder() else {
    fatalError("Could not encode the 'smawk' dispatch.")
  }
  encoder.setComputePipelineState(pipeline)
  var nValue = n
  encoder.setBytes(&nValue, length: MemoryLayout<Int32>.stride, index: 0)
  encoder.setBuffer(buffer_cols, offset: 0, index: 1)
  encoder.setBuffer(buffer_D, offset: 0, index: 2)
  encoder.setBuffer(buffer_cumsum, offset: 0, index: 3)
  encoder.setBuffer(buffer_cumsum2, offset: 0, index: 4)
  encoder.setBuffer(buffer_result, offset: 0, index: 5)
  // The kernel only works when dispatched as a single GPU thread.
  encoder.dispatchThreads(MTLSizeMake(1, 1, 1), threadsPerThreadgroup: MTLSizeMake(1, 1, 1))
  encoder.endEncoding()
  commandBuffer.commit()
  commandBuffer.waitUntilCompleted()
  if let error = commandBuffer.error {
    fatalError("'smawk' failed on the GPU: \(error)")
  }

  // Copy each GPU buffer back over its own host array (previously every
  // output was filled from the `cols` buffer).
  readBack(buffer_cols, into: cols)
  readBack(buffer_D, into: D)
  readBack(buffer_cumsum, into: cumsum)
  readBack(buffer_cumsum2, into: cumsum2)
  readBack(buffer_result, into: result)
}
/// Returns the Metal Shading Language source that defines the `smawk` compute
/// kernel and its helpers (`_kmeans1d_cost`, `_kmeans1d_lookup`, `_smawk0`,
/// `_smawk1`, `_smawk2`). It is compiled at runtime with `makeLibrary(source:)`.
///
/// NOTE(review): inside `_smawk0` the "Build the reverse lookup table." comment
/// is stale. That level uses column indices directly and builds no table.
func makeShaderSource() -> String {
  let kernelSource = """
  //
  // Kernels.metal
  // KMeans
  //
  // Created by Philip Turner on 6/29/23.
  //
  #include <metal_stdlib>
  using namespace metal;
  inline static float _kmeans1d_cost(device float* cumsum, device float* cumsum2, int i, int j)
  {
  if (j < i)
  return 0;
  float mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1);
  float result = cumsum2[j + 1] - cumsum2[i];
  result += (j - i + 1) * (mu * mu);
  result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]);
  return result;
  }
  inline static float _kmeans1d_lookup(device float* D, device float* cumsum, device float* cumsum2, int i, int j)
  {
  const int col = i < j - 1 ? i : j - 1;
  return (col >= 0 ? D[col] : 0) + _kmeans1d_cost(cumsum, cumsum2, j, i);
  }
  static void _smawk2(int row_start, int row_stride, int row_size, device int* cols, int col_size, device int* reserved, device float* D, device float* cumsum, device float* cumsum2, device int* result)
  {
  if (row_size == 0)
  return;
  device int* _cols = cols + col_size;
  int _col_size = 0;
  int i;
  for (i = 0; i < col_size; i++)
  {
  const int col = cols[i];
  for (;;)
  {
  if (_col_size == 0)
  break;
  const int row = row_start + row_stride * (_col_size - 1);
  if (_kmeans1d_lookup(D, cumsum, cumsum2, row, col) >= _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[_col_size - 1]))
  break;
  --_col_size;
  }
  if (_col_size < row_size)
  {
  _cols[_col_size] = col;
  ++_col_size;
  }
  }
  _smawk2(row_start + row_stride, row_stride * 2, row_size / 2, _cols, _col_size, reserved, D, cumsum, cumsum2, result);
  // Build the reverse lookup table.
  for (i = 0; i < _col_size; i++)
  reserved[_cols[i]] = i;
  int start = 0;
  for (i = 0; i < row_size; i += 2) {
  const int row = row_start + i * row_stride;
  int stop = _col_size - 1;
  if (i < row_size - 1)
  {
  const int argmin = result[row_start + (i + 1) * row_stride];
  stop = reserved[argmin];
  }
  int argmin = _cols[start];
  float min = _kmeans1d_lookup(D, cumsum, cumsum2, row, argmin);
  int c;
  for (c = start + 1; c <= stop; c++)
  {
  float value = _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[c]);
  if (value < min) {
  argmin = _cols[c];
  min = value;
  }
  }
  result[row] = argmin;
  start = stop;
  }
  }
  static void _smawk1(int row_start, int row_stride, int row_size, device int* cols, int col_size, device int* reserved, device float* D, device float* cumsum, device float* cumsum2, device int* result)
  {
  if (row_size == 0)
  return;
  device int* _cols = cols;
  int _col_size = 0;
  int i;
  for (i = 0; i < col_size; i++)
  {
  const int col = i;
  for (;;)
  {
  if (_col_size == 0)
  break;
  const int row = row_start + row_stride * (_col_size - 1);
  if (_kmeans1d_lookup(D, cumsum, cumsum2, row, col) >= _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[_col_size - 1]))
  break;
  --_col_size;
  }
  if (_col_size < row_size)
  {
  _cols[_col_size] = col;
  ++_col_size;
  }
  }
  _smawk2(row_start + row_stride, row_stride * 2, row_size / 2, _cols, _col_size, reserved, D, cumsum, cumsum2, result);
  // Build the reverse lookup table.
  for (i = 0; i < _col_size; i++)
  reserved[_cols[i]] = i;
  int start = 0;
  for (i = 0; i < row_size; i += 2) {
  const int row = row_start + i * row_stride;
  int stop = _col_size - 1;
  if (i < row_size - 1)
  {
  const int argmin = result[row_start + (i + 1) * row_stride];
  stop = reserved[argmin];
  }
  int argmin = _cols[start];
  float min = _kmeans1d_lookup(D, cumsum, cumsum2, row, argmin);
  int c;
  for (c = start + 1; c <= stop; c++)
  {
  float value = _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[c]);
  if (value < min) {
  argmin = _cols[c];
  min = value;
  }
  }
  result[row] = argmin;
  start = stop;
  }
  }
  static void _smawk0(int row_start, int row_stride, int row_size, device int* cols, int col_size, device int* reserved, device float* D, device float* cumsum, device float* cumsum2, device int* result)
  {
  if (row_size == 0)
  return;
  _smawk1(row_start + row_stride, row_stride * 2, row_size / 2, cols, col_size, reserved, D, cumsum, cumsum2, result);
  // Build the reverse lookup table.
  int start = 0;
  int i;
  for (i = 0; i < row_size; i += 2) {
  const int row = row_start + i * row_stride;
  int stop = col_size - 1;
  if (i < row_size - 1)
  {
  const int argmin = result[row_start + (i + 1) * row_stride];
  stop = argmin;
  }
  int argmin = start;
  float min = _kmeans1d_lookup(D, cumsum, cumsum2, row, argmin);
  int c;
  for (c = start + 1; c <= stop; c++)
  {
  float value = _kmeans1d_lookup(D, cumsum, cumsum2, row, c);
  if (value < min) {
  argmin = c;
  min = value;
  }
  }
  result[row] = argmin;
  start = stop;
  }
  }
  // Only works when you dispatch a single GPU thread. Once we verify this works,
  // the next step is supporting multiple threads in one invocation, running
  // different blocks in parallel.
  //
  // Another optimization is working directly on half-precision numbers, which
  // would let more data fit inside the GPU's tiny caches.
  kernel void smawk(constant int &n [[buffer(0)]],
  device int* cols [[buffer(1)]],
  device float* D [[buffer(2)]],
  device float* cumsum [[buffer(3)]],
  device float* cumsum2 [[buffer(4)]],
  device int* result [[buffer(5)]])
  {
  _smawk0(0, 1, n, cols + n, n, cols, D, cumsum, cumsum2, result);
  }
  """
  return kernelSource
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.