Created
December 14, 2023 10:04
-
-
Save Narsil/13bfeb66959a0b7e2750a23be9ecc957 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Metal | |
import MetalPerformanceShadersGraph | |
let type = MTLDataType.float; | |
let dataType = type; | |
var B = 2; | |
var M = 2; | |
var N = 2; | |
var K = 2; | |
var A_trans = false; | |
var B_trans = false; | |
var D_trans = false; | |
var alpha = Float(1.0); | |
var beta = Float(0.0); | |
var batched = B > 1; | |
var fused_activation = false; | |
var fused_bias = false; | |
let constants = MTLFunctionConstantValues() | |
constants.setConstantValue(&M, type: .uint, index: 0) | |
constants.setConstantValue(&N, type: .uint, index: 1) | |
constants.setConstantValue(&K, type: .uint, index: 2) | |
constants.setConstantValue(&A_trans, type: .bool, index: 10) | |
constants.setConstantValue(&B_trans, type: .bool, index: 11) | |
constants.setConstantValue(&D_trans, type: .bool, index: 13) | |
constants.setConstantValue(&alpha, type: .float, index: 20) | |
constants.setConstantValue(&beta, type: .float, index: 21) | |
constants.setConstantValue(&batched, type: .bool, index: 100) | |
constants.setConstantValue(&fused_activation, type: .bool, index: 101) | |
constants.setConstantValue(&fused_bias, type: .bool, index: 50001) | |
var M_simd = UInt16(16) | |
var N_simd = UInt16(16) | |
var K_simd = UInt16(32) | |
var M_splits = UInt16(2) | |
var N_splits = UInt16(2) | |
constants.setConstantValue(&M_simd, type: .ushort, index: 200) | |
constants.setConstantValue(&N_simd, type: .ushort, index: 201) | |
constants.setConstantValue(&K_simd, type: .ushort, index: 202) | |
constants.setConstantValue(&M_splits, type: .ushort, index: 210) | |
constants.setConstantValue(&N_splits, type: .ushort, index: 211) | |
let M_group = M_simd * M_splits | |
let N_group = N_simd * N_splits | |
// Satisfy Metal API validation. | |
#if DEBUG | |
do { | |
var garbage: SIMD4<UInt64> = .zero | |
constants.setConstantValue(&garbage, type: .bool, index: 102) | |
constants.setConstantValue(&garbage, type: .bool, index: 103) | |
constants.setConstantValue(&garbage, type: .bool, index: 113) | |
constants.setConstantValue(&garbage, type: .bool, index: 50000) | |
} | |
#endif | |
let device = MTLCopyAllDevices().first! | |
device.shouldMaximizeConcurrentCompilation = true | |
var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!; | |
libraryURL.append(component: "src") | |
libraryURL.append(component: "libMetalFlashAttention.metallib") | |
let library = try! device.makeLibrary(URL: libraryURL) | |
var name: String | |
switch dataType { | |
case .half: name = "hgemm" | |
case .float: name = "sgemm" | |
default: fatalError() | |
} | |
let function = try! library.makeFunction( | |
name: name, constantValues: constants) | |
let A_block_length = M_group * K_simd | |
let B_block_length = K_simd * N_group | |
var blockElements = A_block_length + B_block_length; | |
if (M % 8 != 0) && (N % 8 != 0) { | |
let C_block_length = M_group * N_group; | |
blockElements = max(C_block_length, blockElements) | |
} | |
if fused_bias { | |
if D_trans { | |
blockElements = max(blockElements, M_group) | |
} else { | |
blockElements = max(blockElements, N_group) | |
} | |
} | |
// let blockBytes = blockElements * UInt16(dataType.size) | |
let elementSize = 4 | |
let blockBytes = blockElements * UInt16(elementSize) | |
func ceilDivide(target: Int, granularity: UInt16) -> Int { | |
(target + Int(granularity) - 1) / Int(granularity) | |
} | |
var gridSize = MTLSize( | |
width: ceilDivide(target: N, granularity: N_group), | |
height: ceilDivide(target: M, granularity: M_group), | |
depth: 1) | |
let groupSize = MTLSize( | |
width: Int(32 * M_splits * N_splits), | |
height: 1, | |
depth: 1) | |
let commandQueue = device.makeCommandQueue()! | |
let threadgroupMemoryLength = blockBytes; | |
let rowsA = M; | |
let columnsA = K; | |
let rowsB = K; | |
let columnsB = N; | |
let rowsC = M; | |
let columnsC = N; | |
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA) | |
var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB) | |
var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC) | |
var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC) | |
for i in 0..<arrayA.count { | |
arrayA[i] = Float(i) | |
} | |
for i in 0..<arrayB.count { | |
arrayB[i] = Float(i) | |
} | |
let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])! | |
let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])! | |
let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])! | |
let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])! | |
let pipeline = try device.makeComputePipelineState(function: function) | |
func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer){ | |
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)! | |
encoder.setComputePipelineState(pipeline) | |
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0) | |
encoder.setBuffer(bufferA, offset: 0, index: 0) | |
encoder.setBuffer(bufferB, offset: 0, index: 1) | |
encoder.setBuffer(bufferC, offset: 0, index: 2) | |
let gridZ: Int = B | |
if batched{ | |
func byteStride(shape: [Int]) -> Int { | |
let rank = shape.count | |
var output = elementSize * shape[rank - 2] * shape[rank - 1] | |
if shape.dropLast(2).reduce(1, *) == 1 { | |
output = 0 | |
} | |
return output | |
} | |
let byteStrideA = M*K*elementSize | |
let byteStrideB = N*K*elementSize | |
let byteStrideC = M*N*elementSize | |
let byteStrideD = 0 | |
withUnsafeTemporaryAllocation( | |
of: SIMD4<UInt64>.self, capacity: gridZ | |
) { buffer in | |
for i in 0..<buffer.count { | |
buffer[i] = SIMD4( | |
UInt64(truncatingIfNeeded: i * byteStrideA), | |
UInt64(truncatingIfNeeded: i * byteStrideB), | |
UInt64(truncatingIfNeeded: i * byteStrideC), | |
UInt64(truncatingIfNeeded: i * byteStrideD)) | |
} | |
let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride | |
assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4) | |
encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10) | |
} | |
} | |
gridSize.depth = gridZ | |
encoder.dispatchThreadgroups( | |
gridSize, threadsPerThreadgroup: groupSize | |
) | |
encoder.endEncoding() | |
} | |
var commandBuffer = commandQueue.makeCommandBuffer()! | |
call(bufferA:bufferA, bufferB:bufferB, bufferC:bufferC) | |
commandBuffer.commit() | |
commandBuffer = commandQueue.makeCommandBuffer()! | |
call(bufferA:bufferA, bufferB:bufferC, bufferC:bufferD) | |
commandBuffer.commit() | |
commandBuffer.waitUntilCompleted() | |
var contents = bufferC.contents(); | |
var count = B * rowsA * columnsB; | |
var typedPointer = contents.bindMemory(to: Float.self, capacity: count) | |
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) | |
print("First matmul is OK", Array(bufferedPointer)) | |
contents = bufferD.contents(); | |
count = B * rowsA * columnsB; | |
typedPointer = contents.bindMemory(to: Float.self, capacity: count) | |
bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count) | |
print("This should be filled", Array(bufferedPointer)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment