Simulate the bandwidth achieved while executing feedforward layers in LLaMA
import Metal
// MARK: - Usage
// Usage:
//
// 1) Install Xcode from the Mac App Store
//
// 2) From the command line:
//
// touch GEMV_bandwidth.swift
//
// 3) Copy the contents of this script into GEMV_bandwidth.swift
//
// 4) Choose whether to profile LLaMA.cpp's algorithm or the optimized GEMV
//
let feedforwardAlgorithm: Algorithm = {
  guard CommandLine.arguments.count == 2,
        CommandLine.arguments[1] == "--llama-cpp" ||
        CommandLine.arguments[1] == "--llama-cpp-new" ||
        CommandLine.arguments[1] == "--gemv-quantized-i4" else {
    print("""
      Usage: swift GEMV_bandwidth.swift [--llama-cpp | --llama-cpp-new | \
      --gemv-quantized-i4]
      """)
    exit(0)
  }
  switch CommandLine.arguments[1] {
  case "--llama-cpp": return Algorithm.LLaMA_cpp
  case "--llama-cpp-new": return Algorithm.LLaMA_cpp_new
  case "--gemv-quantized-i4": return Algorithm.GEMV_quantized_i4
  default: fatalError("This should never happen.")
  }
}()
if CommandLine.arguments[1] == "--gemv-quantized-i4" {
  print("""
    \u{1b}[0;36mWARNING:\u{1b}[0m \u{1b}[0;32m--gemv-quantized-i4\u{1b}[0m has \
    a bug. Do not use it to estimate maximum achievable bandwidth or minimum \
    achievable latency.
    """)
}
//
// 5) From the command line (the last kernel has a bug):
//
// swift GEMV_bandwidth.swift --llama-cpp
// swift GEMV_bandwidth.swift --llama-cpp-new
// swift GEMV_bandwidth.swift --gemv-quantized-i4
//
// 6) Look for the bandwidth in GB/s
//
// % swift GEMV_bandwidth.swift --gemv-quantized-i4
// 316.4 GB/s
// ≥11.4 ms/token
//
// 7) Repeat steps 5-6 three times, and report every number you get.
//
// % swift GEMV_bandwidth.swift --gemv-quantized-i4
// 316.4 GB/s
// ≥11.4 ms/token
// % swift GEMV_bandwidth.swift --gemv-quantized-i4
// 317.2 GB/s
// ≥11.4 ms/token
// % swift GEMV_bandwidth.swift --gemv-quantized-i4
// 315.6 GB/s
// ≥11.4 ms/token
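//
// For example, one (hypothetical) way to script step 7 from a zsh/bash shell,
// not part of the original instructions:
//
// % for i in 1 2 3; do swift GEMV_bandwidth.swift --llama-cpp; done
//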
// MARK: - Shader Parameters
// How many times to sample bandwidth while warming the caches.
let numTrials: Int = 10
// The algorithm used to multiply 4-bit matrices with 16/32-bit vectors.
enum Algorithm {
  // slower kernel originally in LLaMA.cpp
  case LLaMA_cpp
  // kernel currently in the GGML Metal backend
  case LLaMA_cpp_new
  // faster kernel, ONLY RUNS ON M1/M2/Pro/Max/Ultra
  // WARNING: This has a bug, so it overestimates bandwidth.
  case GEMV_quantized_i4
  var gpuFunctionName: String {
    switch self {
    case .LLaMA_cpp:
      return "dequantize_mul_mat_vec_q4_0"
    case .LLaMA_cpp_new:
      return "kernel_mul_mat_q4_0_f32"
    case .GEMV_quantized_i4:
      return "gemv_quantized_i4"
    }
  }
}
// MARK: - Bandwidth Profiling
try testLLaMA()
func testLLaMA() throws {
  let devices = MTLCopyAllDevices()
  let device = devices.first(where: { !$0.isLowPower }) ?? devices.first!
  let commandQueue = device.makeCommandQueue()!
  let library = try device.makeLibrary(source: getShaderSource(), options: nil)
  let constants = MTLFunctionConstantValues()
  var ncols: UInt32 = 4096
  constants.setConstantValue(&ncols, type: .uint, index: 0)
  let functionName = feedforwardAlgorithm.gpuFunctionName
  var function = try! library.makeFunction(
    name: functionName, constantValues: constants)
  let w13Pipeline = try! device.makeComputePipelineState(function: function)
  ncols = 4096 * 4
  constants.setConstantValue(&ncols, type: .uint, index: 0)
  function = try! library.makeFunction(
    name: functionName, constantValues: constants)
  let w2Pipeline = try! device.makeComputePipelineState(function: function)
  let matrixElements = 4096 * 4096 * 4
  let matrixSize = matrixElements / 2 + (matrixElements / 32) * 2
  let vectorSize = 4096 * 4
  assert(matrixSize == 4096 * 16384 / 2 * 9 / 8)
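  // Sanity math for the 4-bit layout above (explanatory note, derived from
  // the two lines computing matrixSize): every 32 weights store 32 four-bit
  // quants (16 bytes) plus one half-precision scale (2 bytes), i.e. 18 bytes
  // per 32 weights, or 4.5 bits per weight. For a 4096 x 16384 matrix that is
  // 32 MiB of quants + 4 MiB of scales = 36 MiB, hence the 9/8 factor in the
  // assert. A redundant, added check of the same identity:
  assert(matrixSize == (matrixElements / 32) * (16 + 2))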
  // Run all 32 layers in quick succession to find asymptotic maximum bandwidth.
  struct Vectors {
    var x: MTLBuffer
    var w1Val: MTLBuffer // quadruple the vector size
    var w3Val: MTLBuffer // quadruple the vector size
    var output: MTLBuffer
    init(device: MTLDevice, vectorSize: Int) {
      x = device.makeBuffer(length: vectorSize)!
      w1Val = device.makeBuffer(length: vectorSize * 4)!
      w3Val = device.makeBuffer(length: vectorSize * 4)!
      output = device.makeBuffer(length: vectorSize)!
    }
  }
  struct Context {
    var vectors: Vectors
    var w13Pipeline: MTLComputePipelineState
    var w2Pipeline: MTLComputePipelineState
  }
  // Don't do anything with the vectors written back to RAM.
  struct FeedForward {
    var matrixSize: Int
    var weights1: MTLBuffer
    var weights2: MTLBuffer
    var weights3: MTLBuffer
    init(device: MTLDevice, matrixSize: Int) {
      self.matrixSize = matrixSize
      weights1 = device.makeBuffer(length: matrixSize)!
      weights2 = device.makeBuffer(length: matrixSize)!
      weights3 = device.makeBuffer(length: matrixSize)!
    }
    var totalMemory: Int {
      weights1.length + weights2.length + weights3.length
    }
    func encode(encoder: MTLComputeCommandEncoder, context ctx: Context) {
      var simdRowStride: Int
      var simdsPerGroup: Int
      switch feedforwardAlgorithm {
      case .LLaMA_cpp:
        simdRowStride = 1
        simdsPerGroup = 1
      case .LLaMA_cpp_new:
        simdRowStride = -1
        simdsPerGroup = -1
      case .GEMV_quantized_i4:
        simdRowStride = 4
        simdsPerGroup = 4
      }
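      // Dispatch geometry (explanatory note, derived from the dispatches
      // below): for GEMV_quantized_i4, simdRowStride = 4 rows per simdgroup
      // and simdsPerGroup = 4 simdgroups per threadgroup, so a 16384-row
      // matrix launches 16384 / 4 / 4 = 1024 threadgroups of 32 * 4 = 128
      // threads. The -1 sentinels for LLaMA_cpp_new are never read; that
      // branch hard-codes one 8 x 4 threadgroup per output row instead.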
      if feedforwardAlgorithm == .LLaMA_cpp_new {
        func encodeInteger(_ value: Int, index: Int) {
          var valueCopy = value
          encoder.setBytes(&valueCopy, length: 8, index: index)
          /*
          constant int64_t & ne00 [[buffer(3)]],
          constant int64_t & ne01 [[buffer(4)]],
          constant uint64_t & nb00 [[buffer(5)]],
          constant uint64_t & nb01 [[buffer(6)]],
          constant uint64_t & nb02 [[buffer(7)]],
          constant int64_t & ne10 [[buffer(8)]],
          constant int64_t & ne11 [[buffer(9)]],
          constant uint64_t & nb10 [[buffer(10)]],
          constant uint64_t & nb11 [[buffer(11)]],
          constant uint64_t & nb12 [[buffer(12)]],
          constant int64_t & ne0 [[buffer(13)]],
          constant int64_t & ne1 [[buffer(14)]],
          */
        }
        var nrows: UInt32 = 4096 * 4
        func encodeCommon() {
          encodeInteger(-1, index: 4)
          encodeInteger(-1, index: 5)
          encodeInteger(-1, index: 6)
          encodeInteger(-1, index: 7)
          encodeInteger(-1, index: 9)
          encodeInteger(-1, index: 10)
          encodeInteger(-1, index: 11)
          encodeInteger(-1, index: 12)
          encodeInteger(-1, index: 14)
          encoder.dispatchThreadgroups(
            MTLSizeMake(Int(nrows), 1, 1),
            threadsPerThreadgroup: MTLSizeMake(8, 4, 1))
        }
        encoder.setComputePipelineState(ctx.w13Pipeline)
        encoder.setThreadgroupMemoryLength(32 * 4, index: 0)
        encoder.setBuffer(weights1, offset: 0, index: 0)
        encoder.setBuffer(ctx.vectors.x, offset: 0, index: 1)
        encoder.setBuffer(ctx.vectors.w1Val, offset: 0, index: 2)
        encodeInteger(4096, index: 3) // ncols
        encodeInteger(4096, index: 8) // ncols
        encodeInteger(1, index: 13) // always 1
        encodeCommon()
        encoder.setComputePipelineState(ctx.w13Pipeline)
        encoder.setThreadgroupMemoryLength(32 * 4, index: 0)
        encoder.setBuffer(weights3, offset: 0, index: 0)
        encoder.setBuffer(ctx.vectors.x, offset: 0, index: 1)
        encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 2)
        encodeInteger(4096, index: 3) // ncols
        encodeInteger(4096, index: 8) // ncols
        encodeInteger(1, index: 13) // always 1
        encodeCommon()
        nrows = 4096
        encoder.setComputePipelineState(ctx.w2Pipeline)
        encoder.setThreadgroupMemoryLength(32 * 4, index: 0)
        encoder.setBuffer(weights2, offset: 0, index: 0)
        encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 1)
        encoder.setBuffer(ctx.vectors.output, offset: 0, index: 2)
        encodeInteger(4096 * 4, index: 3) // ncols
        encodeInteger(4096 * 4, index: 8) // ncols
        encodeInteger(1, index: 13) // always 1
        encodeCommon()
      } else {
        encoder.setComputePipelineState(ctx.w13Pipeline)
        encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
        encoder.setBuffer(weights1, offset: 0, index: 0)
        encoder.setBuffer(weights1, offset: matrixSize * 8 / 9, index: 1)
        encoder.setBuffer(ctx.vectors.x, offset: 0, index: 2)
        encoder.setBuffer(ctx.vectors.w1Val, offset: 0, index: 3)
        var ncols: UInt32 = 4096
        var nrows: UInt32 = 4096 * 4
        encoder.setBytes(&ncols, length: 4, index: 4)
        encoder.dispatchThreadgroups(
          MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
          threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))
        encoder.setComputePipelineState(ctx.w13Pipeline)
        encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
        encoder.setBuffer(weights3, offset: 0, index: 0)
        encoder.setBuffer(weights3, offset: matrixSize * 8 / 9, index: 1)
        encoder.setBuffer(ctx.vectors.x, offset: 0, index: 2)
        encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 3)
        encoder.setBytes(&ncols, length: 4, index: 4)
        encoder.dispatchThreadgroups(
          MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
          threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))
        encoder.setComputePipelineState(ctx.w2Pipeline)
        encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
        encoder.setBuffer(weights2, offset: 0, index: 0)
        encoder.setBuffer(weights2, offset: matrixSize * 8 / 9, index: 1)
        encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 2)
        encoder.setBuffer(ctx.vectors.output, offset: 0, index: 3)
        ncols = 4096 * 4
        nrows = 4096
        encoder.setBytes(&ncols, length: 4, index: 4)
        encoder.dispatchThreadgroups(
          MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
          threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))
      }
    }
  }
  let numLayers = 32
  let vectors = Vectors(device: device, vectorSize: vectorSize)
  let context = Context(
    vectors: vectors, w13Pipeline: w13Pipeline, w2Pipeline: w2Pipeline)
  var layers: [FeedForward] = []
  for _ in 0..<numLayers {
    let layer = FeedForward(device: device, matrixSize: matrixSize)
    layers.append(layer)
  }
  var maxBandwidth: Float = 0
  for _ in 0..<numTrials {
    let commandBuffer = commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    for layer in layers {
      layer.encode(encoder: encoder, context: context)
    }
    encoder.endEncoding()
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
    let time = commandBuffer.gpuEndTime - commandBuffer.gpuStartTime
    let bytes = Double(numLayers * layers[0].totalMemory)
    let bandwidth = bytes / time / 1e9
    maxBandwidth = max(maxBandwidth, Float(bandwidth))
  }
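  // Worked example of the figures printed below (numbers taken from the
  // sample output in the usage section above): 32 layers x 3 matrices x
  // ~36 MiB of quantized weights is roughly 3.6 GB streamed per token. At the
  // measured 316.4 GB/s, one pass over those weights takes at least
  // 3.6 GB / 316.4 GB/s ≈ 11.4 ms, hence "≥11.4 ms/token".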
  print("""
    \u{1b}[0;33m\(String(format: "%.1f", Float(maxBandwidth))) GB/s\u{1b}[0m
    """)
  let modelSize = Float(numLayers * layers[0].totalMemory)
  let latency = modelSize / (maxBandwidth * 1e9)
  let usPerToken = Int(latency / 1e-6)
  let ms = usPerToken / 1000
  let us = usPerToken % 1000
  print("\u{1b}[0;33m≥\(ms).\(us / 100) ms/token\u{1b}[0m")
}
// Bypass a bug in the Swift interpreter, where global variables declared later
// in the file are initialized to empty or zero.
func getShaderSource() -> String {
  return """
#define COMPILE_LEGACY_LLAMA_CPP_OPENCL_SHADER 1
#define COMPILE_NEW_LLAMA_CPP_METAL_SHADER 1
#include <metal_stdlib>
using namespace metal;
// Perform a feedforward layer of LLaMA-6.7B. Ensure you are cycling through
// 32 instances of the layer weights, otherwise they will fall into the
// system-level cache.
//
// Inference from both the LLaMA.cpp format and a bandwidth-optimized format.
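//
// Why 32 copies (explanatory note): each layer's three matrices total about
// 108 MiB here, and one trial streams 32 distinct copies, roughly 3.6 GB.
// Reusing a single copy could let part of it be served from the system-level
// cache (up to tens of MB, depending on the chip), inflating the apparent
// DRAM bandwidth.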
#if COMPILE_LEGACY_LLAMA_CPP_OPENCL_SHADER
// Reference implementation from LLaMA.cpp. The custom implementation stores
// weights in a different format.
#define QK4_0 32
#define QR4_0 2
struct __attribute__ ((packed)) __old_block_q4_0
{
    half d;
    uint8_t qs[QK4_0 / 2];
};
void dequantize_q4_0
(
    const device __old_block_q4_0* x,
    const int ib,
    const int iqs,
    thread float* v0,
    thread float* v1)
{
    const float d = float(x[ib].d);
    const uint8_t vui = x[ib].qs[iqs];
    const int8_t vi0 = vui & 0xF;
    const int8_t vi1 = vui >> 4;
    *v0 = (vi0 - 8)*d;
    *v1 = (vi1 - 8)*d;
}
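// Worked example (illustrative): if qs[iqs] == 0xA3, the low nibble 0x3 and
// high nibble 0xA decode to v0 = (3 - 8)*d = -5*d and v1 = (10 - 8)*d = 2*d;
// each nibble is an unsigned 4-bit value with an implicit zero point of 8.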
// Original:
// 88.5 GB/s
// Don't read out-of-bounds vector data:
// 96.0 GB/s
kernel void dequantize_mul_mat_vec_q4_0
(
    device __old_block_q4_0* x [[buffer(0)]],
    threadgroup float* tmp [[threadgroup(0)]],
    device float* y [[buffer(2)]],
    device float* dst [[buffer(3)]],
    constant uint &ncols [[buffer(4)]],
    uint block_size [[threads_per_threadgroup]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]])
{
    const uint row = global_id / block_size;
    const uint qk = QK4_0;
    const uint qr = QR4_0;
    const int y_offset = qr == 1 ? 1 : qk/2;
    tmp[local_id] = 0;
    for (uint i = 0; i < ncols/block_size; i += 2) {
        const uint col = i*block_size + 2*local_id;
        const uint ib = (row*ncols + col)/qk; // block index
        const uint iqs = (col%qk)/qr;         // quant index
        const uint iybs = col - col%qk;       // y block start index
        // dequantize
        float v0, v1;
        dequantize_q4_0(x, ib, iqs, &v0, &v1);
        // matrix multiplication
        tmp[local_id] += v0 * y[iybs + iqs + 0];
        tmp[local_id] += v1 * y[iybs + iqs + y_offset];
    }
    // sum up partial sums and write back result
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s=block_size/2; s>0; s>>=1) {
        if (local_id < s) {
            tmp[local_id] += tmp[local_id + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    if (local_id == 0) {
        dst[row] = tmp[0];
    }
}
#undef QK4_0
#undef QR4_0
#endif
#if COMPILE_NEW_LLAMA_CPP_METAL_SHADER
#define QK4_0 32
#define QR4_0 2
typedef struct {
    half d;                 // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;
kernel void kernel_mul_mat_q4_0_f32(
    device const void * src0 [[buffer(0)]],
    device const float * src1 [[buffer(1)]],
    device float * dst [[buffer(2)]],
    constant int64_t & ne00 [[buffer(3)]],
    constant int64_t & ne01 [[buffer(4)]],
    constant uint64_t & nb00 [[buffer(5)]],
    constant uint64_t & nb01 [[buffer(6)]],
    constant uint64_t & nb02 [[buffer(7)]],
    constant int64_t & ne10 [[buffer(8)]],
    constant int64_t & ne11 [[buffer(9)]],
    constant uint64_t & nb10 [[buffer(10)]],
    constant uint64_t & nb11 [[buffer(11)]],
    constant uint64_t & nb12 [[buffer(12)]],
    constant int64_t & ne0 [[buffer(13)]],
    constant int64_t & ne1 [[buffer(14)]],
    threadgroup float * sum [[threadgroup(0)]],
    uint2 tgpig[[threadgroup_position_in_grid]],
    uint2 tpig[[thread_position_in_grid]],
    uint2 tpitg[[thread_position_in_threadgroup]],
    uint2 tptg[[threads_per_threadgroup]]) {
    const int nb = ne00/QK4_0;
    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
    device const float * y = (device const float *) src1 + r1*ne10;
    const uint nth = tptg.x*tptg.y;
    const uint ith = tptg.y*tpitg.x + tpitg.y;
    sum[ith] = 0.0f;
    for (int i = tpitg.x; i < nb; i += tptg.x) {
        device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
        device const float4 * y0p = (device const float4 *) (y + i*QK4_0);
        const float d = (float)((x + i)->d);
        const uchar4 x0v = *(x0p + tpitg.y);
        const float4 y0v = *(y0p + tpitg.y + 0);
        const float4 y1v = *(y0p + tpitg.y + 4);
        float acc = 0.0f;
        for (int j = 0; j < 4; ++j) {
            const int x0 = x0v[j] & 0x0F;
            const int x1 = x0v[j] >> 4;
            const float y0 = y0v[j];
            const float y1 = y1v[j];
            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
        }
        sum[ith] += acc*d;
    }
    // accumulate the sum from all threads in the threadgroup
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = nth/2; i > 0; i /= 2) {
        if (ith < i) {
            sum[ith] += sum[ith + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    if (ith == 0) {
        dst[r1*ne0 + r0] = sum[0];
    }
}
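// Thread layout (explanatory note, matching the 8 x 4 threadgroup dispatched
// from the host code): tpitg.x in 0..7 strides over a row's 32-value blocks,
// while tpitg.y in 0..3 picks one uchar4 (8 quants) inside a block plus the
// matching float4 from each half of the y block (the two halves sit four
// float4s apart). The tree reduction then folds the 32 partial sums into
// dst[r1*ne0 + r0].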
#undef QK4_0
#undef QR4_0
#endif
// Original:
// 88.5 GB/s
// Switch from threadgroup to simdgroup sum:
// 90.2 GB/s
// Deinterleave the weights and scales:
// 98.5 GB/s
// Hard-code shader parameters and remove `if (local_id == 0)`:
// 139.4 GB/s
// Directly index `uint8_t` instead of a struct:
// 143.6 GB/s
// Don't read out-of-bounds vector data:
// 172.6 GB/s
// Read input vectors as half-precision:
// 193.8 GB/s
// Coalesce the accesses to y and read the correct value from `weights`:
// 199.4 GB/s
// Change threadgroup size from 32 to 64:
// 210.4 GB/s
// Change threadgroup size from 32 to 128:
// 211.7 GB/s
// Two rows per simd:
// 225.4 GB/s
// Four rows per simd:
// 226.7 GB/s
// Unroll two iterations of the loop / un-duplicate scale reads:
// 243.3 GB/s
// Coalesce two Y reads:
// 245.7 GB/s
// Perform both X reads at the same time:
// 253.9 GB/s
// Unroll four iterations of the loop:
// (BAD DATA) 292.4 GB/s
// Coalesce X reads:
// (BAD DATA) 353.6 GB/s
// Coalesce four Y reads and use the correct index within a row:
// (BAD DATA) 374.8 GB/s
// Change how the buffers are indexed:
// (BAD DATA) 406.9 GB/s
// Use the correct value for 'i':
// 304.0 GB/s
// Optimize how 'vui' is stored in registers:
// 306.1 GB/s
// Optimize the generation of the index for scales:
// 319.3 GB/s
constant uint ncols [[function_constant(0)]];
// SIMD shuffle instructions require Metal 3 support, although the
// function itself is heavily optimized for the Apple 7 architecture.
//
// TODO: Support different quantization formats through function constants.
kernel void gemv_quantized_i4
(
    device uchar4 *weights [[buffer(0)]],
    device half2 *scales [[buffer(1)]],
    device half *y [[buffer(2)]],
    device half *dst [[buffer(3)]],
    uint tid [[thread_position_in_grid]])
{
    // 8-wide groupings of threads, each thread reads 8 values per iteration.
    // 'groupings of threads' != 'threadgroups'
#define WEIGHTS_PER_UINT 8
#define GROUPING_SIZE 8
    uint row = tid / GROUPING_SIZE;
    ushort local_id = tid % GROUPING_SIZE;
    float acc = 0;
    // Changing this to a `while` loop harms performance. Perhaps it triggers a
    // separate assembly instruction for control flow.
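    // Index sketch (explanatory note): the 8 threads of a grouping share one
    // output row. Per iteration each thread loads one uchar4 (8 packed
    // weights) and two half4 slices of y, and `i` advances by
    // 2 * 4 * GROUPING_SIZE = 64 columns. Note the weights index below has no
    // row term, which appears to be (at least part of) the bug that the
    // warning at the top of this script refers to.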
    for (uint i = 0; i < ncols;) {
        uchar4 vui = weights[i / WEIGHTS_PER_UINT + local_id];
        const uint blocks_in_row = ncols / 32;
        half2 d = scales[row * blocks_in_row / 2 + i / 32 / 2];
        {
            half4 y_value = *(device half4*)(y + i + 4 * local_id);
            i += 4 * GROUPING_SIZE;
            const short vi0 = vui.x & 0xF;
            const short vi1 = vui.x >> 4;
            float v0 = (vi0 - 8) * d.x;
            float v1 = (vi1 - 8) * d.x;
            acc += v0 * y_value[0];
            acc += v1 * y_value[1];
            const short vi2 = vui.y & 0xF;
            const short vi3 = vui.y >> 4;
            float v2 = (vi2 - 8) * d.x;
            float v3 = (vi3 - 8) * d.x;
            acc += v2 * y_value[2];
            acc += v3 * y_value[3];
        }
        {
            half4 y_value = *(device half4*)(y + i + 4 * local_id);
            i += 4 * GROUPING_SIZE;
            const short vi0 = vui.z & 0xF;
            const short vi1 = vui.z >> 4;
            float v0 = (vi0 - 8) * d.y;
            float v1 = (vi1 - 8) * d.y;
            acc += v0 * y_value[0];
            acc += v1 * y_value[1];
            const short vi2 = vui.w & 0xF;
            const short vi3 = vui.w >> 4;
            float v2 = (vi2 - 8) * d.y;
            float v3 = (vi3 - 8) * d.y;
            acc += v2 * y_value[2];
            acc += v3 * y_value[3];
        }
    }
    acc += quad_shuffle_xor(acc, 1);
    acc += quad_shuffle_xor(acc, 2);
    acc += simd_shuffle_xor(acc, 4);
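    // Reduction (explanatory note): the three XOR shuffles (distances 1, 2, 4)
    // form a butterfly sum over the 8 lanes that share a row, so every lane
    // ends up holding the full dot product. All 8 lanes then store the same
    // value to dst[row], avoiding the divergent `if (local_id == 0)` branch
    // mentioned in the changelog above.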
    dst[row] = acc;
#undef WEIGHTS_PER_UINT
#undef GROUPING_SIZE
}
"""
}