Skip to content

Instantly share code, notes, and snippets.

View geohot's full-sized avatar

George Hotz geohot

View GitHub Profile
@geohot
geohot / hip.py
Created November 25, 2023 23:28
Wrapper for HIP
# -*- coding: utf-8 -*-
#
# TARGET arch is: ['-D__HIP_PLATFORM_AMD__', '-I/opt/rocm/include']
# WORD_SIZE is: 8
# POINTER_SIZE is: 8
# LONGDOUBLE_SIZE is: 16
#
import ctypes
@geohot
geohot / memcpy.py
Created November 21, 2023 19:21
Fast memcpy using GPUs
# tiny@tiny9:~/tinygrad$ python3 examples/benchmark_copies.py
# CPU copy 6.18 ms, 16.28 GB/s
# GPU copy 4.38 ms, 23.00 GB/s
# GPU 6x 1.85 ms, 54.54 GB/s
import time
def timeit(fxn):
tms = []
for _ in range(10):
st = time.perf_counter()
@geohot
geohot / matmul.cl
Last active April 16, 2025 16:04
A 1024x1024x1024 matmul with a 2x2x2 core in OpenCL
__kernel void matmul(__global float* data0, const __global float* data1, const __global float* data2) {
int gidx0 = get_group_id(1); /* 512 */
int gidx1 = get_group_id(0); /* 512 */
float2 acc0 = (float2)(0.0f,0.0f);
float2 acc1 = (float2)(0.0f,0.0f);
for (int ridx0 = 0; ridx0 < 512; ++ridx0) {
float2 val0 = (float2)(*((__global float2*)(data1+(gidx0*2048)+(ridx0*2))));
float2 val1 = (float2)(*((__global float2*)(data1+(gidx0*2048)+(ridx0*2)+1024)));
float2 val2 = (float2)(*((__global float2*)(data2+(gidx1*2)+(ridx0*2048))));
float2 val3 = (float2)(*((__global float2*)(data2+(gidx1*2)+(ridx0*2048)+1024)));
@geohot
geohot / cifar_wino_kernels
Created October 19, 2023 05:48
kernels for BS=1024 CIFAR BEAM=2 WINO=1
*** 0 E_64_32_6_6n5 arg 2 sz [64, 1, 1] [32, 1, 1] OPs 33M/ 0.00G mem 3.07 GB tm 3.20us/ 0.00ms (10483.20 GFLOPS, 297.02 GB/s)
*** 1 r_128_31_31_3_2_3_2_2_2_8n26 arg 3 sz [31, 31, 128] [2, 3, 1] OPs 283M/ 0.03G mem 3.07 GB tm 218.44us/ 0.22ms ( 1297.42 GFLOPS, 216.24 GB/s)
*** 2 r_1024_32_16_2_3_4_4_8n6 arg 3 sz [32, 1024, 1] [2, 16, 1] OPs 805M/ 0.32G mem 3.07 GB tm 64.68us/ 0.29ms (12450.43 GFLOPS, 1426.62 GB/s)
@geohot
geohot / test_allreduce.py
Created March 14, 2023 07:19
Test bandwidth of all-reduce
import os
import sys
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def all_reduce_latency(nbytes, rank):
buf = torch.randn(nbytes // 4).cuda(rank)
@geohot
geohot / mod_range.py
Last active March 4, 2023 06:33
mod of a range
# given a number in the range [amin, amax] (inclusive)
# what are the min and max of that number after modding it by b?
# c style modulus: the remainder takes the sign of the dividend `a` (as with
# C's truncating `%`), unlike Python's `%`, whose result takes the sign of `b`.
def modn(a, b):
    """Return `a` modulo `b` with C semantics (remainder has the sign of `a`).

    The magnitude is always ``abs(a) % abs(b)`` and the sign follows the
    dividend, so modn(-7, 3) == -1 and modn(7, -3) == 1, whereas Python
    gives -7 % 3 == 2 and 7 % -3 == -2. Works for any sign of `b`, not
    just b > 0; results for b > 0 are the same as before.

    Raises:
      ZeroDivisionError: if `b` is 0.
    """
    # Using abs(b) keeps Python's `%` from leaking the divisor's sign into the result.
    r = abs(a) % abs(b)
    return -r if a < 0 else r
# i.e. we want a fast equivalent of this brute-force reference:
def slow_modrange(amin, amax, b):
    """Brute-force reference: (min, max) of modn(v, b) for every v in [amin, amax].

    Enumerates the whole inclusive range, so the cost grows linearly with
    amax - amin. An empty range (amin > amax) raises ValueError from min().
    """
    residues = tuple(modn(v, b) for v in range(amin, amax + 1))
    lowest, highest = min(residues), max(residues)
    return lowest, highest
import torch
# Turn autograd off globally, so the in-place write to the layer's weight below
# is permitted. Autograd otherwise rejects in-place writes to a leaf parameter
# that requires grad.
torch.set_grad_enabled(False)
# Bias-free 1 -> 1 linear layer on the GPU: the output is exactly weight * input.
model = torch.nn.Linear(1, 1, bias=False).cuda()
model.weight[:] = 1.
# With weight == 1 and no bias, the mathematically exact result is 2349.0;
# printing it shows whether the CUDA forward pass reproduces that value exactly.
# NOTE(review): presumably probing reduced-precision matmul (e.g. TF32) on CUDA — confirm.
print(model(torch.Tensor([2349.]).cuda()))
@geohot
geohot / gist:569e9e4b20fd41203d8da71c6022be15
Last active April 30, 2024 21:39
Instructions to install openpilot on a Pixel 3 running Android 9
# Instructions to install openpilot on a Pixel 3.
# Enter fastboot mode by holding power + volume down.
# Make sure the bootloader is unlocked.
# Make sure a modern version of the Android platform tools is installed.
mkdir pixel
# NOTE(review): "pixel" is created but never entered, so the download and unzip
# below happen in the current directory; presumably a `cd pixel` was intended — confirm.
# There is no `set -e`: if a step fails, the following steps still run.
wget https://dl.google.com/dl/android/aosp/blueline-pq3a.190801.002-factory-f3d66c49.zip
unzip blueline-pq3a.190801.002-factory-f3d66c49.zip
cd blueline-pq3a.190801.002/
# Script shipped inside the unzipped factory image; presumably it flashes that
# image onto the device connected in fastboot mode — confirm before running.
./flash-all.sh
@geohot
geohot / prius_kf.py
Last active March 9, 2021 07:36
Prius Steering Angle Kalman Filter
%pylab inline
%load_ext autoreload
%autoreload 2
from tools.lib.route import Route
from tools.lib.logreader import LogReader
r,num = Route("ce2fbd370f78ef21|2020-11-27--16-27-28"),10
#r,num = Route("f66032c2b5aa18ac|2020-12-04--09-33-54"),30
alr = []
for n in range(num-1, num+5):
@geohot
geohot / clang_fore.diff
Created July 30, 2020 01:08
Add support for "fore" loops to clang
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index 13f265223..61b0a83c6 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -2459,13 +2459,16 @@ class ForStmt : public Stmt {
public:
ForStmt(const ASTContext &C, Stmt *Init, Expr *Cond, VarDecl *condVar,
Expr *Inc, Stmt *Body, SourceLocation FL, SourceLocation LP,
- SourceLocation RP);
+ SourceLocation RP, bool is_fore_statement=false);