Skip to content

Instantly share code, notes, and snippets.

@psinger
Created February 11, 2025 16:17
Show Gist options
  • Save psinger/4c0be78770d1b84d641e9dab2208c9b0 to your computer and use it in GitHub Desktop.
Save psinger/4c0be78770d1b84d641e9dab2208c9b0 to your computer and use it in GitHub Desktop.
import torch
import time
from bitsandbytes.nn import Params4bit
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4, to_nf4
import bitsandbytes as bnb
import pandas as pd
def _time_cuda_callable(fn, num_iters, num_warmup):
    """Return the mean wall-clock seconds per call of ``fn`` on the GPU.

    ``fn`` is called ``num_warmup`` times untimed (to keep one-off start-up
    costs out of the measurement), then ``num_iters`` times inside the timed
    region. ``torch.cuda.synchronize()`` brackets the timed region because
    CUDA work is launched asynchronously and would otherwise not be counted.
    """
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution interval clock
    for _ in range(num_iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_iters


def benchmark_forward_pass(batch_size, hidden_dim, method='nf4tensor', num_iters=100,
                           num_warmup=10, device='cuda:0'):
    """Run forward pass benchmark for given parameters.

    Builds a random bfloat16 input of shape ``(batch_size, hidden_dim)`` and a
    random bfloat16 square weight of shape ``(hidden_dim, hidden_dim)`` on
    ``device``, quantizes the weight to NF4 with the selected backend, and
    times the linear-layer forward call.

    Args:
        batch_size: Number of input rows.
        hidden_dim: Input and output feature size of the square weight.
        method: ``'nf4tensor'`` (torchao ``to_nf4`` + ``linear_nf4``) or
            ``'params4bit'`` (bitsandbytes ``Params4bit`` + ``matmul_4bit``).
        num_iters: Number of timed calls to average over; must be >= 1.
        num_warmup: Untimed calls made before timing starts.
        device: CUDA device to allocate tensors on.

    Returns:
        Average time per iteration in seconds.

    Raises:
        ValueError: If ``method`` is unknown or ``num_iters < 1``.
    """
    # Validate before touching the GPU: an unknown method used to silently
    # benchmark params4bit, and num_iters == 0 used to divide by zero.
    if method not in ('nf4tensor', 'params4bit'):
        raise ValueError(
            f"unknown method {method!r}; expected 'nf4tensor' or 'params4bit'"
        )
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    device = torch.device(device)
    # Allocate directly in bfloat16: creating in fp32 and casting adds a
    # full-precision temporary that inflates the caller's peak-memory reading.
    x = torch.randn(batch_size, hidden_dim, device=device, dtype=torch.bfloat16)
    weight = torch.randn(hidden_dim, hidden_dim, device=device, dtype=torch.bfloat16)

    if method == 'nf4tensor':
        weight_nf4 = to_nf4(weight, block_size=64, scaler_block_size=256)

        def forward():
            return linear_nf4(input=x, weight=weight_nf4)
    else:
        weight_4bit = Params4bit(
            data=weight,
            quant_type="nf4",
        )._quantize(device=device)

        def forward():
            # The transpose and the (presumably no-op) cast to bfloat16 stay
            # inside the timed call so the measured work matches the original.
            return bnb.matmul_4bit(
                x,
                weight_4bit.t(),
                bias=None,
                quant_state=weight_4bit.quant_state,
            ).to(torch.bfloat16)

    return _time_cuda_callable(forward, num_iters, num_warmup)
def _mean_std(values):
    """Return ``(mean, sample_std)`` of ``values``.

    The std uses Bessel's correction (matching ``torch.Tensor.std``'s default)
    and is computed in float64. It is reported as NaN when fewer than two
    values are given, since a sample std is undefined there.
    """
    mean = sum(values) / len(values)
    if len(values) < 2:
        return mean, float('nan')
    return mean, torch.tensor(values, dtype=torch.float64).std().item()


def _measure_once(batch_size, hidden_dim, method):
    """Run one benchmark and return ``(seconds_per_iter, peak_allocated_mb)``.

    The caching allocator is emptied and the peak counter reset first, so the
    peak reflects allocations from this point on (plus anything still live).
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    seconds = benchmark_forward_pass(batch_size, hidden_dim, method)
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    return seconds, peak_mb


def run_benchmarks(batch_sizes=(1, 8, 32, 128),
                   hidden_dims=(1024, 2048, 4096, 8192),
                   num_outer_loops=10):
    """Benchmark NF4Tensor vs Params4bit over a grid of shapes.

    For every ``(batch_size, hidden_dim)`` pair, each method is measured
    ``num_outer_loops`` times (NF4Tensor first, then Params4bit, per loop).
    Per-configuration results are printed as they complete.

    Args:
        batch_sizes: Batch sizes to sweep.
        hidden_dims: Hidden dimensions to sweep.
        num_outer_loops: Number of times to repeat each benchmark; must be >= 1.

    Returns:
        DataFrame with columns ``batch_size``, ``hidden_dim``, ``method``,
        ``time_ms`` (mean), ``time_std_ms`` (sample std, NaN for one loop),
        and ``memory_mb`` (mean peak allocated MB). One row per method and
        configuration.

    Raises:
        ValueError: If ``num_outer_loops < 1``.
    """
    if num_outer_loops < 1:
        raise ValueError(f"num_outer_loops must be >= 1, got {num_outer_loops}")

    columns = ['batch_size', 'hidden_dim', 'method', 'time_ms', 'time_std_ms', 'memory_mb']
    # (display label stored in results, method name passed to the benchmark)
    methods = (('NF4Tensor', 'nf4tensor'), ('Params4bit', 'params4bit'))
    results = []

    print("\n" + "="*80)
    print(f"{'Batch Size':^15} {'Hidden Dim':^15} {'Method':^15} {'Time (ms)':^15} {'Memory (MB)':^15}")
    print("="*80)

    for batch_size in batch_sizes:
        for hidden_dim in hidden_dims:
            times = {label: [] for label, _ in methods}
            mems = {label: [] for label, _ in methods}
            for i in range(num_outer_loops):
                for label, method in methods:
                    seconds, peak_mb = _measure_once(batch_size, hidden_dim, method)
                    times[label].append(seconds)
                    mems[label].append(peak_mb)
                # Previously `i % 10 == 0` only ever printed "0/10" with 10 loops.
                print(f"Progress: {i + 1}/{num_outer_loops} iterations", end='\r')

            print(f"\nResults for batch_size={batch_size}, hidden_dim={hidden_dim}")
            for label, _ in methods:
                avg_time, std_time = _mean_std(times[label])
                avg_mem = sum(mems[label]) / num_outer_loops
                results.append({
                    'batch_size': batch_size,
                    'hidden_dim': hidden_dim,
                    'method': label,
                    'time_ms': avg_time * 1000,
                    'time_std_ms': std_time * 1000,
                    'memory_mb': avg_mem,
                })
                print(f"{label + ':':<11} {avg_time*1000:>8.2f} ± {std_time*1000:>5.2f} ms, {avg_mem:>8.1f} MB")
            print("-"*80)

    # Explicit columns keep the schema stable even when the sweep is empty.
    return pd.DataFrame(results, columns=columns)
def main():
    """Run the benchmark sweep and print summary and speedup tables."""
    print("Starting NF4 benchmarking...")
    results_df = run_benchmarks()

    print("\nSummary Statistics:")
    print("="*80)
    summary = results_df.pivot_table(
        index=['batch_size', 'hidden_dim'],
        columns='method',
        values=['time_ms', 'time_std_ms', 'memory_mb'],
        aggfunc='mean'
    )
    print(summary)

    # Ratio > 1 means NF4Tensor was faster than Params4bit for that shape.
    print("\nSpeedup Ratios (Params4bit time / NF4Tensor time):")
    print("="*80)
    # Reuse `summary` instead of pivoting the same data a second time.
    p4_mean = summary[('time_ms', 'Params4bit')]
    nf4_mean = summary[('time_ms', 'NF4Tensor')]
    p4_std = summary[('time_std_ms', 'Params4bit')]
    nf4_std = summary[('time_std_ms', 'NF4Tensor')]
    ratio = p4_mean / nf4_mean
    # The old comment promised an uncertainty but none was computed. This is a
    # first-order propagated std for a quotient, assuming the two timings are
    # independent: sigma_r / r = sqrt((sigma_p / p)^2 + (sigma_n / n)^2).
    ratio_std = ratio * ((p4_std / p4_mean) ** 2 + (nf4_std / nf4_mean) ** 2) ** 0.5
    speedup = pd.DataFrame({'ratio': ratio, 'ratio_std': ratio_std})
    print(speedup)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment