Simulate the bandwidth achieved while executing feedforward layers in LLaMA
import Metal
// MARK: - Usage
// Usage:
//
// 1) Install Xcode from the Mac App Store
//
// 2) From the command line:
//
// touch GEMV_bandwidth.swift
//
// 3) Copy the contents of this script into GEMV_bandwidth.swift
//
// 4) Choose whether to profile LLaMA.cpp's algorithm or the optimized GEMV
//
let feedforwardAlgorithm: Algorithm = {
  guard CommandLine.arguments.count == 2,
    CommandLine.arguments[1] == "--llama-cpp" ||
    CommandLine.arguments[1] == "--llama-cpp-new" ||
    CommandLine.arguments[1] == "--gemv-quantized-i4" else {
    print("""
    Usage: swift GEMV_bandwidth.swift [--llama-cpp | --llama-cpp-new | \
    --gemv-quantized-i4]
    """)
    exit(0)
  }
  switch CommandLine.arguments[1] {
  case "--llama-cpp": return Algorithm.LLaMA_cpp
  case "--llama-cpp-new": return Algorithm.LLaMA_cpp_new
  case "--gemv-quantized-i4": return Algorithm.GEMV_quantized_i4
  default: fatalError("This should never happen.")
  }
}()
if CommandLine.arguments[1] == "--gemv-quantized-i4" {
  print("""
  \u{1b}[0;36mWARNING:\u{1b}[0m \u{1b}[0;32m--gemv-quantized-i4\u{1b}[0m has \
  a bug. Do not use it to estimate maximum achievable bandwidth or minimum \
  achievable latency.
  """)
}
//
// 5) From the command line (the last kernel has a bug):
//
// swift GEMV_bandwidth.swift --llama-cpp
// swift GEMV_bandwidth.swift --llama-cpp-new
// swift GEMV_bandwidth.swift --gemv-quantized-i4
//
// 6) Look for the bandwidth in GB/s
//
// % swift GEMV_bandwidth.swift --gemv-quantized-i4
// 316.4 GB/s
// ≥11.4 ms/token
//
// 7) Repeat steps 5 - 6 three times, and report every number you get.
//
// % swift GEMV_bandwidth.swift --gemv-quantized-i4
// 316.4 GB/s
// ≥11.4 ms/token
// % swift GEMV_bandwidth.swift --gemv-quantized-i4
// 317.2 GB/s
// ≥11.4 ms/token
// % swift GEMV_bandwidth.swift --gemv-quantized-i4
// 315.6 GB/s
// ≥11.4 ms/token
// MARK: - Shader Parameters
// How many times to sample bandwidth while warming the caches.
let numTrials: Int = 10
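// Only the fastest trial is reported (see maxBandwidth below), so the later,
// cache-warmed iterations dominate the result.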
// The algorithm used to multiply 4-bit matrices with 16/32-bit vectors.
enum Algorithm {
  // slower kernel originally in LLaMA.cpp
  case LLaMA_cpp
  // kernel currently in the GGML Metal backend
  case LLaMA_cpp_new
  // faster kernel, ONLY RUNS ON M1/M2/Pro/Max/Ultra
  // WARNING: This has a bug, so it overestimates bandwidth.
  case GEMV_quantized_i4

  var gpuFunctionName: String {
    switch self {
    case .LLaMA_cpp:
      return "dequantize_mul_mat_vec_q4_0"
    case .LLaMA_cpp_new:
      return "kernel_mul_mat_q4_0_f32"
    case .GEMV_quantized_i4:
      return "gemv_quantized_i4"
    }
  }
}
// MARK: - Bandwidth Profiling
try testLLaMA()
func testLLaMA() throws {
  let devices = MTLCopyAllDevices()
  let device = devices.first(where: { !$0.isLowPower }) ?? devices.first!
  let commandQueue = device.makeCommandQueue()!
  let library = try device.makeLibrary(source: getShaderSource(), options: nil)

  let constants = MTLFunctionConstantValues()
  var ncols: UInt32 = 4096
  constants.setConstantValue(&ncols, type: .uint, index: 0)
  let functionName = feedforwardAlgorithm.gpuFunctionName
  var function = try! library.makeFunction(
    name: functionName, constantValues: constants)
  let w13Pipeline = try! device.makeComputePipelineState(function: function)

  ncols = 4096 * 4
  constants.setConstantValue(&ncols, type: .uint, index: 0)
  function = try! library.makeFunction(
    name: functionName, constantValues: constants)
  let w2Pipeline = try! device.makeComputePipelineState(function: function)

  let matrixElements = 4096 * 4096 * 4
  let matrixSize = matrixElements / 2 + (matrixElements / 32) * 2
  let vectorSize = 4096 * 4
  assert(matrixSize == 4096 * 16384 / 2 * 9 / 8)
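  // 9/8 factor: each q4_0 block stores 32 weights as 16 bytes of packed
  // nibbles plus a 2-byte half-precision scale (18 bytes per 32 weights).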
  // Run all 32 layers in quick succession to find asymptotic maximum bandwidth.
  struct Vectors {
    var x: MTLBuffer
    var w1Val: MTLBuffer // quadruple the vector size
    var w3Val: MTLBuffer // quadruple the vector size
    var output: MTLBuffer

    init(device: MTLDevice, vectorSize: Int) {
      x = device.makeBuffer(length: vectorSize)!
      w1Val = device.makeBuffer(length: vectorSize * 4)!
      w3Val = device.makeBuffer(length: vectorSize * 4)!
      output = device.makeBuffer(length: vectorSize)!
    }
  }

  struct Context {
    var vectors: Vectors
    var w13Pipeline: MTLComputePipelineState
    var w2Pipeline: MTLComputePipelineState
  }

  // Don't do anything with the vectors written back to RAM.
  struct FeedForward {
    var matrixSize: Int
    var weights1: MTLBuffer
    var weights2: MTLBuffer
    var weights3: MTLBuffer

    init(device: MTLDevice, matrixSize: Int) {
      self.matrixSize = matrixSize
      weights1 = device.makeBuffer(length: matrixSize)!
      weights2 = device.makeBuffer(length: matrixSize)!
      weights3 = device.makeBuffer(length: matrixSize)!
    }

    var totalMemory: Int {
      weights1.length + weights2.length + weights3.length
    }

    func encode(encoder: MTLComputeCommandEncoder, context ctx: Context) {
      var simdRowStride: Int
      var simdsPerGroup: Int
      switch feedforwardAlgorithm {
      case .LLaMA_cpp:
        simdRowStride = 1
        simdsPerGroup = 1
      case .LLaMA_cpp_new:
        simdRowStride = -1
        simdsPerGroup = -1
      case .GEMV_quantized_i4:
        simdRowStride = 4
        simdsPerGroup = 4
      }
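      // Dispatch shape: nrows / (simdRowStride * simdsPerGroup) threadgroups
      // of 32 * simdsPerGroup threads. For --llama-cpp that is one 32-thread
      // threadgroup per row; for --gemv-quantized-i4, 128-thread threadgroups
      // where each simdgroup covers 4 rows (8 threads per row). The -1 values
      // for --llama-cpp-new are unused; that path dispatches inside
      // encodeCommon below with an 8x4 threadgroup per row.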
      if feedforwardAlgorithm == .LLaMA_cpp_new {
        func encodeInteger(_ value: Int, index: Int) {
          var valueCopy = value
          encoder.setBytes(&valueCopy, length: 8, index: index)
          /*
          constant int64_t & ne00 [[buffer(3)]],
          constant int64_t & ne01 [[buffer(4)]],
          constant uint64_t & nb00 [[buffer(5)]],
          constant uint64_t & nb01 [[buffer(6)]],
          constant uint64_t & nb02 [[buffer(7)]],
          constant int64_t & ne10 [[buffer(8)]],
          constant int64_t & ne11 [[buffer(9)]],
          constant uint64_t & nb10 [[buffer(10)]],
          constant uint64_t & nb11 [[buffer(11)]],
          constant uint64_t & nb12 [[buffer(12)]],
          constant int64_t & ne0 [[buffer(13)]],
          constant int64_t & ne1 [[buffer(14)]],
          */
        }
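        // With a one-dimensional grid (r1 == 0), kernel_mul_mat_q4_0_f32 only
        // reads ne00, ne10, and ne0; every other argument is filled with a -1
        // placeholder by encodeCommon.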
        var nrows: UInt32 = 4096 * 4
        func encodeCommon() {
          encodeInteger(-1, index: 4)
          encodeInteger(-1, index: 5)
          encodeInteger(-1, index: 6)
          encodeInteger(-1, index: 7)
          encodeInteger(-1, index: 9)
          encodeInteger(-1, index: 10)
          encodeInteger(-1, index: 11)
          encodeInteger(-1, index: 12)
          encodeInteger(-1, index: 14)
          encoder.dispatchThreadgroups(
            MTLSizeMake(Int(nrows), 1, 1),
            threadsPerThreadgroup: MTLSizeMake(8, 4, 1))
        }

        encoder.setComputePipelineState(ctx.w13Pipeline)
        encoder.setThreadgroupMemoryLength(32 * 4, index: 0)
        encoder.setBuffer(weights1, offset: 0, index: 0)
        encoder.setBuffer(ctx.vectors.x, offset: 0, index: 1)
        encoder.setBuffer(ctx.vectors.w1Val, offset: 0, index: 2)
        encodeInteger(4096, index: 3) // ncols
        encodeInteger(4096, index: 8) // ncols
        encodeInteger(1, index: 13) // always 1
        encodeCommon()

        encoder.setComputePipelineState(ctx.w13Pipeline)
        encoder.setThreadgroupMemoryLength(32 * 4, index: 0)
        encoder.setBuffer(weights3, offset: 0, index: 0)
        encoder.setBuffer(ctx.vectors.x, offset: 0, index: 1)
        encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 2)
        encodeInteger(4096, index: 3) // ncols
        encodeInteger(4096, index: 8) // ncols
        encodeInteger(1, index: 13) // always 1
        encodeCommon()

        nrows = 4096
        encoder.setComputePipelineState(ctx.w2Pipeline)
        encoder.setThreadgroupMemoryLength(32 * 4, index: 0)
        encoder.setBuffer(weights2, offset: 0, index: 0)
        encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 1)
        encoder.setBuffer(ctx.vectors.output, offset: 0, index: 2)
        encodeInteger(4096 * 4, index: 3) // ncols
        encodeInteger(4096 * 4, index: 8) // ncols
        encodeInteger(1, index: 13) // always 1
        encodeCommon()
      } else {
        encoder.setComputePipelineState(ctx.w13Pipeline)
        encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
        encoder.setBuffer(weights1, offset: 0, index: 0)
        encoder.setBuffer(weights1, offset: matrixSize * 8 / 9, index: 1)
        encoder.setBuffer(ctx.vectors.x, offset: 0, index: 2)
        encoder.setBuffer(ctx.vectors.w1Val, offset: 0, index: 3)
        var ncols: UInt32 = 4096
        var nrows: UInt32 = 4096 * 4
        encoder.setBytes(&ncols, length: 4, index: 4)
        encoder.dispatchThreadgroups(
          MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
          threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))

        encoder.setComputePipelineState(ctx.w13Pipeline)
        encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
        encoder.setBuffer(weights3, offset: 0, index: 0)
        encoder.setBuffer(weights3, offset: matrixSize * 8 / 9, index: 1)
        encoder.setBuffer(ctx.vectors.x, offset: 0, index: 2)
        encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 3)
        encoder.setBytes(&ncols, length: 4, index: 4)
        encoder.dispatchThreadgroups(
          MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
          threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))

        encoder.setComputePipelineState(ctx.w2Pipeline)
        encoder.setThreadgroupMemoryLength(4 * 32, index: 0)
        encoder.setBuffer(weights2, offset: 0, index: 0)
        encoder.setBuffer(weights2, offset: matrixSize * 8 / 9, index: 1)
        encoder.setBuffer(ctx.vectors.w3Val, offset: 0, index: 2)
        encoder.setBuffer(ctx.vectors.output, offset: 0, index: 3)
        ncols = 4096 * 4
        nrows = 4096
        encoder.setBytes(&ncols, length: 4, index: 4)
        encoder.dispatchThreadgroups(
          MTLSizeMake(Int(nrows) / simdRowStride / simdsPerGroup, 1, 1),
          threadsPerThreadgroup: MTLSizeMake(32 * simdsPerGroup, 1, 1))
      }
    }
  }
  let numLayers = 32
  let vectors = Vectors(device: device, vectorSize: vectorSize)
  let context = Context(
    vectors: vectors, w13Pipeline: w13Pipeline, w2Pipeline: w2Pipeline)
  var layers: [FeedForward] = []
  for _ in 0..<numLayers {
    let layer = FeedForward(device: device, matrixSize: matrixSize)
    layers.append(layer)
  }

  var maxBandwidth: Float = 0
  for _ in 0..<numTrials {
    let commandBuffer = commandQueue.makeCommandBuffer()!
    let encoder = commandBuffer.makeComputeCommandEncoder()!
    for layer in layers {
      layer.encode(encoder: encoder, context: context)
    }
    encoder.endEncoding()
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()

    let time = commandBuffer.gpuEndTime - commandBuffer.gpuStartTime
    let bytes = Double(numLayers * layers[0].totalMemory)
    let bandwidth = bytes / time / 1e9
    maxBandwidth = max(maxBandwidth, Float(bandwidth))
  }
  print("""
  \u{1b}[0;33m\(String(format: "%.1f", Float(maxBandwidth))) GB/s\u{1b}[0m
  """)
  let modelSize = Float(numLayers * layers[0].totalMemory)
  let latency = modelSize / (maxBandwidth * 1e9)
  let usPerToken = Int(latency / 1e-6)
  let ms = usPerToken / 1000
  let us = usPerToken % 1000
  print("\u{1b}[0;33m≥\(ms).\(us / 100) ms/token\u{1b}[0m")
}
// Bypass a bug in the Swift interpreter, where global variables declared later
// in the file are initialized to empty or zero.
func getShaderSource() -> String {
  return """
#define COMPILE_LEGACY_LLAMA_CPP_OPENCL_SHADER 1
#define COMPILE_NEW_LLAMA_CPP_METAL_SHADER 1
#include <metal_stdlib>
using namespace metal;
// Perform a feedforward layer of LLaMA-6.7B. Ensure you are cycling through
// 32 instances of the layer weights, otherwise they will fall into the
// system-level cache.
//
// Inference from both the LLaMA.cpp format and a bandwidth-optimized format.
#if COMPILE_LEGACY_LLAMA_CPP_OPENCL_SHADER
// Reference implementation from LLaMA.cpp. The custom implementation stores
// weights in a different format.
#define QK4_0 32
#define QR4_0 2
struct __attribute__ ((packed)) __old_block_q4_0
{
    half d;
    uint8_t qs[QK4_0 / 2];
};
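// Each block covers 32 consecutive weights: `d` is the shared scale and `qs`
// packs 32 4-bit quants two per byte, for 18 bytes per block.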
void dequantize_q4_0
(
    const device __old_block_q4_0* x,
    const int ib,
    const int iqs,
    thread float* v0,
    thread float* v1)
{
    const float d = float(x[ib].d);
    const uint8_t vui = x[ib].qs[iqs];
    const int8_t vi0 = vui & 0xF;
    const int8_t vi1 = vui >> 4;
    *v0 = (vi0 - 8)*d;
    *v1 = (vi1 - 8)*d;
}
// Original:
// 88.5 GB/s
// Don't read out-of-bounds vector data:
// 96.0 GB/s
kernel void dequantize_mul_mat_vec_q4_0
(
    device __old_block_q4_0* x [[buffer(0)]],
    threadgroup float* tmp [[threadgroup(0)]],
    device float* y [[buffer(2)]],
    device float* dst [[buffer(3)]],
    constant uint &ncols [[buffer(4)]],
    uint block_size [[threads_per_threadgroup]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]])
{
    const uint row = global_id / block_size;
    const uint qk = QK4_0;
    const uint qr = QR4_0;
    const int y_offset = qr == 1 ? 1 : qk/2;
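    // For q4_0 (qr == 2), y_offset == 16: the low nibble of each byte pairs
    // with y[iybs + iqs], the high nibble with the value 16 positions later in
    // the same 32-wide block.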
    tmp[local_id] = 0;
    for (uint i = 0; i < ncols/block_size; i += 2) {
        const uint col = i*block_size + 2*local_id;
        const uint ib = (row*ncols + col)/qk; // block index
        const uint iqs = (col%qk)/qr; // quant index
        const uint iybs = col - col%qk; // y block start index

        // dequantize
        float v0, v1;
        dequantize_q4_0(x, ib, iqs, &v0, &v1);

        // matrix multiplication
        tmp[local_id] += v0 * y[iybs + iqs + 0];
        tmp[local_id] += v1 * y[iybs + iqs + y_offset];
    }

    // sum up partial sums and write back result
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s=block_size/2; s>0; s>>=1) {
        if (local_id < s) {
            tmp[local_id] += tmp[local_id + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    if (local_id == 0) {
        dst[row] = tmp[0];
    }
}
#undef QK4_0
#undef QR4_0
#endif
#if COMPILE_NEW_LLAMA_CPP_METAL_SHADER
#define QK4_0 32
#define QR4_0 2
typedef struct {
    half d;                // delta
    uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
kernel void kernel_mul_mat_q4_0_f32(
        device const void * src0 [[buffer(0)]],
        device const float * src1 [[buffer(1)]],
        device float * dst [[buffer(2)]],
        constant int64_t & ne00 [[buffer(3)]],
        constant int64_t & ne01 [[buffer(4)]],
        constant uint64_t & nb00 [[buffer(5)]],
        constant uint64_t & nb01 [[buffer(6)]],
        constant uint64_t & nb02 [[buffer(7)]],
        constant int64_t & ne10 [[buffer(8)]],
        constant int64_t & ne11 [[buffer(9)]],
        constant uint64_t & nb10 [[buffer(10)]],
        constant uint64_t & nb11 [[buffer(11)]],
        constant uint64_t & nb12 [[buffer(12)]],
        constant int64_t & ne0 [[buffer(13)]],
        constant int64_t & ne1 [[buffer(14)]],
        threadgroup float * sum [[threadgroup(0)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint2 tpig[[thread_position_in_grid]],
        uint2 tpitg[[thread_position_in_threadgroup]],
        uint2 tptg[[threads_per_threadgroup]]) {
    const int nb = ne00/QK4_0;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;

    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
    device const float * y = (device const float *) src1 + r1*ne10;

    const uint nth = tptg.x*tptg.y;
    const uint ith = tptg.y*tpitg.x + tpitg.y;

    sum[ith] = 0.0f;

    for (int i = tpitg.x; i < nb; i += tptg.x) {
        device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
        device const float4 * y0p = (device const float4 *) (y + i*QK4_0);

        const float d = (float)((x + i)->d);

        const uchar4 x0v = *(x0p + tpitg.y);
        const float4 y0v = *(y0p + tpitg.y + 0);
        const float4 y1v = *(y0p + tpitg.y + 4);
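        // y0v holds the 4 y values paired with the low nibbles; y1v (4 float4s
        // == 16 floats further on) holds the values paired with the high
        // nibbles in the second half of the 32-wide block.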
        float acc = 0.0f;
        for (int j = 0; j < 4; ++j) {
            const int x0 = x0v[j] & 0x0F;
            const int x1 = x0v[j] >> 4;

            const float y0 = y0v[j];
            const float y1 = y1v[j];

            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
        }
        sum[ith] += acc*d;
    }

    // accumulate the sum from all threads in the threadgroup
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = nth/2; i > 0; i /= 2) {
        if (ith < i) {
            sum[ith] += sum[ith + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    if (ith == 0) {
        dst[r1*ne0 + r0] = sum[0];
    }
}
#undef QK4_0
#undef QR4_0
#endif
// Original:
// 88.5 GB/s
// Switch from threadgroup to simdgroup sum:
// 90.2 GB/s
// Deinterleave the weights and scales:
// 98.5 GB/s
// Hard-code shader parameters and remove `if (local_id == 0)`:
// 139.4 GB/s
// Directly index `uint8_t` instead of a struct:
// 143.6 GB/s
// Don't read out-of-bounds vector data:
// 172.6 GB/s
// Read input vectors as half-precision:
// 193.8 GB/s
// Coalesce the accesses to y and read the correct value from `weights`:
// 199.4 GB/s
// Change threadgroup size from 32 to 64:
// 210.4 GB/s
// Change threadgroup size from 32 to 128:
// 211.7 GB/s
// Two rows per simd:
// 225.4 GB/s
// Four rows per simd:
// 226.7 GB/s
// Unroll two iterations of the loop / un-duplicate scale reads:
// 243.3 GB/s
// Coalesce two Y reads:
// 245.7 GB/s
// Perform both X reads at the same time:
// 253.9 GB/s
// Unroll four iterations of the loop:
// (BAD DATA) 292.4 GB/s
// Coalesce X reads:
// (BAD DATA) 353.6 GB/s
// Coalesce four Y reads and use the correct index within a row:
// (BAD DATA) 374.8 GB/s
// Change how the buffers are indexed:
// (BAD DATA) 406.9 GB/s
// Use the correct value for 'i':
// 304.0 GB/s
// Optimize how 'vui' is stored in registers:
// 306.1 GB/s
// Optimize the generation of the index for scales:
// 319.3 GB/s
constant uint ncols [[function_constant(0)]];
// SIMD shuffle instructions require Metal 3 support, although the
// function itself is heavily optimized for the Apple 7 architecture.
//
// TODO: Support different quantization formats through function constants.
kernel void gemv_quantized_i4
(
    device uchar4 *weights [[buffer(0)]],
    device half2 *scales [[buffer(1)]],
    device half *y [[buffer(2)]],
    device half *dst [[buffer(3)]],
    uint tid [[thread_position_in_grid]])
{
    // 8-wide groupings of threads, each thread reads 8 values per iteration.
    // 'groupings of threads' != 'threadgroups'
#define WEIGHTS_PER_UINT 8
#define GROUPING_SIZE 8
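    // Each iteration, one 8-thread grouping consumes 64 consecutive columns:
    // every thread loads one uchar4 (8 packed 4-bit weights) and two half4
    // slices of y.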
    uint row = tid / GROUPING_SIZE;
    ushort local_id = tid % GROUPING_SIZE;
    float acc = 0;

    // Changing this to a `while` loop harms performance. Perhaps it triggers a
    // separate assembly instruction for control flow.
    for (uint i = 0; i < ncols;) {
        uchar4 vui = weights[i / WEIGHTS_PER_UINT + local_id];

        const uint blocks_in_row = ncols / 32;
        half2 d = scales[row * blocks_in_row / 2 + i / 32 / 2];
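        // d.x scales the first 32 columns handled this iteration; d.y scales
        // the next 32 (the second braced section below). This assumes the
        // scales are stored deinterleaved, two consecutive block scales per
        // half2.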
        {
            half4 y_value = *(device half4*)(y + i + 4 * local_id);
            i += 4 * GROUPING_SIZE;

            const short vi0 = vui.x & 0xF;
            const short vi1 = vui.x >> 4;
            float v0 = (vi0 - 8) * d.x;
            float v1 = (vi1 - 8) * d.x;
            acc += v0 * y_value[0];
            acc += v1 * y_value[1];

            const short vi2 = vui.y & 0xF;
            const short vi3 = vui.y >> 4;
            float v2 = (vi2 - 8) * d.x;
            float v3 = (vi3 - 8) * d.x;
            acc += v2 * y_value[2];
            acc += v3 * y_value[3];
        }
        {
            half4 y_value = *(device half4*)(y + i + 4 * local_id);
            i += 4 * GROUPING_SIZE;

            const short vi0 = vui.z & 0xF;
            const short vi1 = vui.z >> 4;
            float v0 = (vi0 - 8) * d.y;
            float v1 = (vi1 - 8) * d.y;
            acc += v0 * y_value[0];
            acc += v1 * y_value[1];

            const short vi2 = vui.w & 0xF;
            const short vi3 = vui.w >> 4;
            float v2 = (vi2 - 8) * d.y;
            float v3 = (vi3 - 8) * d.y;
            acc += v2 * y_value[2];
            acc += v3 * y_value[3];
        }
    }
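    // Butterfly reduction across the 8 threads in the grouping: after the
    // three shuffles every lane holds the full dot product for this row, and
    // all 8 lanes store the same value to dst[row].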
    acc += quad_shuffle_xor(acc, 1);
    acc += quad_shuffle_xor(acc, 2);
    acc += simd_shuffle_xor(acc, 4);
    dst[row] = acc;
#undef WEIGHTS_PER_UINT
#undef GROUPING_SIZE
}
"""
}