Created
February 11, 2025 16:17
-
-
Save psinger/4c0be78770d1b84d641e9dab2208c9b0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import time | |
from bitsandbytes.nn import Params4bit | |
from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4, to_nf4 | |
import bitsandbytes as bnb | |
import pandas as pd | |
def _time_cuda_loop(fn, num_iters, num_warmup):
    """Return the mean seconds per call of ``fn`` over ``num_iters`` timed calls.

    ``fn`` is first called ``num_warmup`` times untimed. The timed region is
    bracketed by ``torch.cuda.synchronize()`` so that all queued GPU work is
    included. ``time.perf_counter`` is used because it is monotonic and
    high-resolution, which ``time.time`` is not guaranteed to be.
    """
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_iters


def benchmark_forward_pass(batch_size, hidden_dim, method='nf4tensor', num_iters=100,
                           num_warmup=10):
    """Run forward pass benchmark for given parameters.

    A random bf16 input of shape (batch_size, hidden_dim) is multiplied by a
    random (hidden_dim, hidden_dim) weight that has been quantized to NF4.

    Args:
        batch_size: Number of rows in the random input ``x``.
        hidden_dim: Feature size; the weight is ``hidden_dim x hidden_dim``.
        method: ``'nf4tensor'`` (torchao ``to_nf4`` + ``linear_nf4``) or
            ``'params4bit'`` (bitsandbytes ``Params4bit`` + ``matmul_4bit``).
        num_iters: Number of timed forward calls; must be >= 1.
        num_warmup: Number of untimed forward calls run before timing.

    Returns:
        Average wall-clock seconds per forward call.

    Raises:
        ValueError: If ``method`` is not recognised or ``num_iters < 1``.
    """
    # Validate before touching the GPU; previously any typo silently ran the
    # Params4bit path and num_iters=0 crashed with ZeroDivisionError.
    if method not in ('nf4tensor', 'params4bit'):
        raise ValueError(
            f"unknown method {method!r}; expected 'nf4tensor' or 'params4bit'"
        )
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    device = torch.device('cuda:0')
    # Sample directly in bf16. Sampling in fp32 and casting allocated a transient
    # fp32 copy, which inflated the peak memory measured by the caller.
    x = torch.randn(batch_size, hidden_dim, device=device, dtype=torch.bfloat16)
    weight = torch.randn(hidden_dim, hidden_dim, device=device, dtype=torch.bfloat16)

    if method == 'nf4tensor':
        # NOTE(review): block_size=64 / scaler_block_size=256 presumably mirror
        # the bitsandbytes NF4 defaults for a like-for-like comparison — confirm.
        weight_nf4 = to_nf4(weight, block_size=64, scaler_block_size=256)

        def forward():
            return linear_nf4(input=x, weight=weight_nf4)
    else:
        weight_4bit = Params4bit(
            data=weight,
            quant_type="nf4",
        )._quantize(device=device)

        def forward():
            return bnb.matmul_4bit(
                x,
                weight_4bit.t(),
                bias=None,
                quant_state=weight_4bit.quant_state,
            ).to(torch.bfloat16)  # Ensure output is bfloat16

    return _time_cuda_loop(forward, num_iters, num_warmup)
def run_benchmarks(batch_sizes=(1, 8, 32, 128), hidden_dims=(1024, 2048, 4096, 8192),
                   num_outer_loops=10):
    """Benchmark NF4Tensor vs Params4bit over a grid of shapes.

    For every (batch_size, hidden_dim) pair, each method is run
    ``num_outer_loops`` times, interleaved as NF4Tensor then Params4bit. Each
    run calls ``benchmark_forward_pass``. The peak CUDA memory is read after
    each run; the peak counter is reset just before it.

    Args:
        batch_sizes: Batch sizes to sweep.
        hidden_dims: Hidden dimensions to sweep.
        num_outer_loops: Repetitions per configuration and method; must be >= 1.

    Returns:
        A DataFrame with one row per (batch_size, hidden_dim, method). Columns
        are ``batch_size``, ``hidden_dim``, ``method``, ``time_ms`` (mean),
        ``time_std_ms`` (sample std across repetitions; NaN if
        ``num_outer_loops == 1``) and ``memory_mb`` (mean peak).

    Raises:
        ValueError: If ``num_outer_loops < 1`` (previously a ZeroDivisionError).
    """
    if num_outer_loops < 1:
        raise ValueError(f"num_outer_loops must be >= 1, got {num_outer_loops}")

    # (label stored in results, key passed to benchmark_forward_pass), in run order.
    methods = (('NF4Tensor', 'nf4tensor'), ('Params4bit', 'params4bit'))
    results = []

    print("\n" + "="*80)
    print(f"{'Batch Size':^15} {'Hidden Dim':^15} {'Method':^15} {'Time (ms)':^15} {'Memory (MB)':^15}")
    print("="*80)

    for batch_size in batch_sizes:
        for hidden_dim in hidden_dims:
            times = {label: [] for label, _ in methods}
            mems = {label: [] for label, _ in methods}

            for i in range(num_outer_loops):
                for label, key in methods:
                    # Start each run from a clean cache and peak counter. The
                    # peak therefore covers input/weight creation and
                    # quantization as well as the timed forward loop.
                    torch.cuda.empty_cache()
                    torch.cuda.reset_peak_memory_stats()
                    times[label].append(benchmark_forward_pass(batch_size, hidden_dim, key))
                    mems[label].append(torch.cuda.max_memory_allocated() / 1024**2)
                # Previously printed only when i % 10 == 0, i.e. once, as "0/N".
                print(f"Progress: {i + 1}/{num_outer_loops} iterations", end='\r')

            print(f"\nResults for batch_size={batch_size}, hidden_dim={hidden_dim}")
            for label, _ in methods:
                avg_time = sum(times[label]) / num_outer_loops
                avg_mem = sum(mems[label]) / num_outer_loops
                # float64 avoids float32 rounding of small per-iteration timings.
                std_time = torch.tensor(times[label], dtype=torch.float64).std().item()
                results.append({
                    'batch_size': batch_size,
                    'hidden_dim': hidden_dim,
                    'method': label,
                    'time_ms': avg_time * 1000,
                    'time_std_ms': std_time * 1000,
                    'memory_mb': avg_mem,
                })
                print(f"{label}: {avg_time*1000:>8.2f} ± {std_time*1000:>5.2f} ms, {avg_mem:>8.1f} MB")
            print("-"*80)

    return pd.DataFrame(results)
def compute_speedup_table(results_df):
    """Return per-configuration Params4bit/NF4Tensor time ratios with spread.

    Args:
        results_df: DataFrame with columns ``batch_size``, ``hidden_dim``,
            ``method`` (values ``'NF4Tensor'`` and ``'Params4bit'``),
            ``time_ms`` and ``time_std_ms``, as returned by ``run_benchmarks``.

    Returns:
        A DataFrame indexed by (batch_size, hidden_dim) with two columns:

        - ``ratio``: Params4bit time / NF4Tensor time. A value > 1 means
          NF4Tensor was faster.
        - ``ratio_std``: first-order propagated std of the ratio, i.e.
          ``ratio * sqrt((s_p / t_p)**2 + (s_n / t_n)**2)``. This assumes the
          two timings are independent.
    """
    times = results_df.pivot_table(
        index=['batch_size', 'hidden_dim'],
        columns='method',
        values=['time_ms', 'time_std_ms'],
        aggfunc='mean',
    )
    t_p4 = times[('time_ms', 'Params4bit')]
    t_nf4 = times[('time_ms', 'NF4Tensor')]
    ratio = t_p4 / t_nf4
    rel_err = (
        (times[('time_std_ms', 'Params4bit')] / t_p4) ** 2
        + (times[('time_std_ms', 'NF4Tensor')] / t_nf4) ** 2
    ) ** 0.5
    return pd.DataFrame({'ratio': ratio, 'ratio_std': ratio * rel_err})


if __name__ == "__main__":
    print("Starting NF4 benchmarking...")
    results_df = run_benchmarks()

    print("\nSummary Statistics:")
    print("="*80)
    summary = results_df.pivot_table(
        index=['batch_size', 'hidden_dim'],
        columns='method',
        values=['time_ms', 'time_std_ms', 'memory_mb'],
        aggfunc='mean'
    )
    print(summary)

    # Speedup ratios with their propagated spread. Previously the std column was
    # pivoted but never used, despite the comment promising error estimates.
    print("\nSpeedup Ratios (Params4bit time / NF4Tensor time):")
    print("="*80)
    print(compute_speedup_table(results_df))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment