@malfet
malfet / pytorch-perf.md
Last active January 15, 2025 17:06
PyTorch LLM perf
dtype SOTA 2.2.2+eager 2.3.0+eager 2.3.0+compile trunk + compile
bfloat16 (M1) 111 tokens/sec 110 tokens/sec 80 tokens/sec
float32 (M1) 687 tokens/sec 165 tokens/sec 176 tokens/sec
float16 (M1) 1106 tokens/sec 50 tokens/sec 187 tokens/sec
float16 (LinX86) 40 tokens/sec 43 tokens/sec 173 tokens/sec
float32 (LinX86) 38 tokens/sec 40 tokens/sec 179 tokens/sec
bfloat16 (LinX86) 73 tokens/sec 78 tokens/sec 180 tokens/sec
malfet / mps_matmul.swift
Created January 9, 2024 02:29
Swift example that runs matrix multiplication on MPS
import MetalPerformanceShadersGraph
let graph = MPSGraph()
let x = graph.constant(1, shape: [32, 4096, 40], dataType: .float32)
let y = graph.constant(1, shape: [32, 40, 4096], dataType: .float32)
let z = graph.matrixMultiplication(primary: x, secondary: y, name: nil)
let device = MTLCreateSystemDefaultDevice()!
let buf = device.makeBuffer(length: 16384)!
let td = MPSGraphTensorData(buf, shape: [64, 64], dataType: .int32)
let cmdBuf = MPSCommandBuffer(from: device.makeCommandQueue()!)
malfet / mm_bmm-perf.py
Last active February 16, 2024 00:27
Measure performance difference of `torch.mm` vs `torch.bmm`
# Benchmark relative performance of torch.mm and torch.bmm with single batch
import torch
import time
def benchmark_fn(fn, args, warmup=5, cycles=300, use_kineto=False) -> float:
    if use_kineto:
        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
            fn(*args)
        return sum([e.cuda_time for e in p.key_averages()])
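The preview above stops before the non-profiler path, so the role of `warmup` and `cycles` is not visible. As a rough illustration only, a wall-clock timing loop with those two parameters might look like the sketch below. The name `benchmark_wall_clock` is hypothetical, and this is an assumption about the general technique, not the gist's actual code.

```python
import time


def benchmark_wall_clock(fn, args, warmup=5, cycles=300):
    """Return the mean wall-clock seconds per call of fn(*args).

    Hypothetical helper, not taken from the gist.
    """
    # Warm-up calls are not timed, so one-off costs (caches, lazy
    # initialization) do not skew the measurement.
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(cycles):
        fn(*args)
    # For asynchronous GPU work the device would also need to be
    # synchronized before reading the clock (for example with
    # torch.cuda.synchronize()); omitted to keep this dependency-free.
    return (time.perf_counter() - start) / cycles


# Usage: time a cheap pure-Python call.
print(benchmark_wall_clock(sum, ([1, 2, 3],), warmup=2, cycles=100))
```

Averaging over many cycles smooths out timer noise, which matters when a single call takes only microseconds.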
malfet / test_trition.py
Last active January 6, 2024 19:47
Test triton
import triton
import triton.language as tl
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xnumel = 10
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
import torch
import torch.nn.functional as F
def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    # scale and clamp the tensor to bring it to
    # the representable range of float8 data type
    # (as default cast is unsaturated)
// My attempt at FP8 matmul implementation
#include <iostream>
#include <vector>
#include <numeric>
#include <cublasLt.h>
#include <cuda_fp8.h>
#include <stdio.h>
malfet / slowregexp.py
Last active June 29, 2023 15:29
Catastrophic backtracking in regexp
# For some reason this does not reproduce when copy-pasted rather than downloaded as a raw file, but otherwise it should hang
import re
pat=re.compile('\\.\\. (code-block|math)::.*$\\n*(?P<S2VCUH>(?P<first>(^(?P<indent>[ ]+).*$\\n))(?P<other>(^([ \\t]+.*|[ \\t]*)$\\n)*))(?:(^(?![ \\t]+.*$))|\\Z)', re.MULTILINE)
text="""#####################################################################
We get the following performance profiling table for the eager-mode model (omitting some columns):
.. code-block:: shell
------------------------- ------------ ------------ ------------ ------------
Name CPU total % CPU total CPU time avg # of Calls
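The pattern and input above are cut off in this preview, so they cannot be run as shown. For readers unfamiliar with the failure mode, here is a much smaller, self-contained sketch of catastrophic backtracking. It uses the textbook nested-quantifier pattern `(a+)+$`, which is an illustrative stand-in and not the gist's pattern; `time_failed_match` is a hypothetical helper.

```python
import re
import time

# Nested quantifiers: the engine can split a run of 'a' characters between
# the inner and outer '+' in exponentially many ways.
PATTERN = re.compile(r"(a+)+$")


def time_failed_match(n):
    """Time a failing match against n 'a' characters followed by 'b'."""
    text = "a" * n + "b"  # the trailing 'b' guarantees that '$' never matches
    start = time.perf_counter()
    result = PATTERN.match(text)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Keep n small: each extra character roughly doubles the work.
for n in (12, 16, 20):
    result, elapsed = time_failed_match(n)
    print(f"n={n:2d} matched={result is not None} time={elapsed:.4f}s")
```

The match always fails, but only after the engine has tried every way of partitioning the run of `a` characters, which is why slightly longer inputs appear to hang.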
malfet / computesqrt.py
Last active March 31, 2023 01:50
Spigot algorithm for computing digits of square root
#!/usr/bin/env python3
# Adapted from https://rosettacode.org/wiki/Square_root_by_hand
def next_digit(val, k):
    for d in range(1, 11):
        if val < d * (k + d):
            return d - 1
    raise RuntimeError("Impossible")
def compute_sqrt(val=2, num_char=500):
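The body of `compute_sqrt` is not included in this preview. As a sketch of how a digit-by-digit ("by hand") square root can be built on `next_digit`, the version below uses plain integer arithmetic. The function `sqrt_digits`, its parameters, and the restriction to `0 <= val < 100` are assumptions made for illustration; the gist's real implementation may differ.

```python
def next_digit(val, k):
    # Largest digit d (0..9) such that d * (k + d) <= val.
    for d in range(1, 11):
        if val < d * (k + d):
            return d - 1
    raise RuntimeError("Impossible")


def sqrt_digits(val=2, num_digits=20):
    """Return the first num_digits decimal digits of sqrt(val) as a string.

    Hypothetical helper; assumes an integer 0 <= val < 100 so that the
    integer part of the root is the single leading digit.
    """
    if not 0 <= val < 100:
        raise ValueError("this sketch only handles 0 <= val < 100")
    root = 0         # digits of the root found so far, read as an integer
    remainder = val  # scaled remainder after subtracting root ** 2
    digits = []
    for _ in range(num_digits):
        # (10 * root + d) ** 2 - (10 * root) ** 2 == d * (20 * root + d),
        # so the next digit is the largest d whose increment still fits.
        d = next_digit(remainder, 20 * root)
        # Subtract the increment and bring down the next pair of zeros.
        remainder = (remainder - d * (20 * root + d)) * 100
        root = root * 10 + d
        digits.append(str(d))
    return "".join(digits)


# The first digit is the integer part, the rest are decimals.
print(sqrt_digits(2, 10))
```

Because every step uses Python's arbitrary-precision integers, the digits stay exact no matter how many are requested.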
malfet / query_utilization.py
Created March 21, 2023 16:20
Query A100 AWS nodes CPU utilization
import boto3
import pandas as pd
from datetime import datetime, timedelta
from typing import Optional
cloudwatch = boto3.client("cloudwatch")
ec2 = boto3.resource("ec2")
def ec2_get_instances(filter_name, filter_value):
    return ec2.instances.filter(Filters=[{'Name': filter_name,
malfet / gist:e098ad49ecde484105b5efc7f50db644
Created February 14, 2023 23:18
Use openai-whisper on CPU vs MPS
Python 3.10.8 (main, Nov 24 2022, 08:08:27) [Clang 14.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import whisper
>>> torch.__version__
'2.0.0a0+git01de5dd'
>>> model = whisper.load_model("base")
>>> audio = whisper.load_audio("c1.mp3") # downloaded from https://www.mobydickbigread.com/chapter-1-loomings/
>>> audio = whisper.pad_or_trim(audio)
>>> model.transcribe(audio)["text"]