@jrruethe
Created August 1, 2024 18:47
Calculate VRAM requirements for LLM models
#!/usr/bin/env ruby
# https://asmirnov.xyz/vram
# https://vram.asmirnov.xyz
require "fileutils"
require "json"
require "open-uri"
# https://huggingface.co/spaces/NyxKrage/LLM-Model-VRAM-Calculator/blob/main/index.html
GGUF_MAPPING =
{
  "Q8_0"    => 8.5,
  "Q6_K"    => 6.59,
  "Q5_K_M"  => 5.69,
  "Q5_K_S"  => 5.54,
  "Q5_0"    => 5.54,
  "Q4_K_M"  => 4.85,
  "Q4_K_S"  => 4.58,
  "Q4_0"    => 4.55,
  "IQ4_NL"  => 4.5,
  "Q3_K_L"  => 4.27,
  "IQ4_XS"  => 4.25,
  "Q3_K_M"  => 3.91,
  "IQ3_M"   => 3.7,
  "IQ3_S"   => 3.5,
  "Q3_K_S"  => 3.5,
  "Q2_K"    => 3.35,
  "IQ3_XS"  => 3.3,
  "IQ3_XXS" => 3.06,
  "IQ2_M"   => 2.7,
  "IQ2_S"   => 2.5,
  "IQ2_XS"  => 2.31,
  "IQ2_XXS" => 2.06,
  "IQ1_S"   => 1.56,
}
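# These are effective bits per weight, including quantization overhead
# (scales and metadata), which is why Q8_0 is 8.5 rather than 8.0.
# For example, an 8B-parameter model at Q4_K_M needs roughly
# 8e9 * 4.85 / 8 bytes, or about 4.5 GiB, for the weights alone.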
# Exllama2 supports arbitrary bits per weight from 2.0 to 6.0 in steps of 0.05;
# listed highest-first so Mode C can return the best quant that fits
EXL2_OPTIONS = (2..6).step(0.05).to_a.map{ |i| i.round(2) }.reverse
# Convert bytes to GiB
# (the upstream calculator tallies bits and divides by 8 * 1024**3;
# here all sizes are accumulated in bytes, so divide by 2**30)
def bytes_to_gb(bytes)
  return bytes.to_f / 2 ** 30
end
def calculate_vram_raw(
  num_params: nil,
  bpw: 5.0,
  lm_head_bpw: 6.0,
  kv_cache_bpw: 8.0,
  context: nil,
  num_gpus: 1,
  num_hidden_layers: nil,
  hidden_size: nil,
  num_key_value_heads: nil,
  num_attention_heads: nil,
  intermediate_size: nil,
  vocab_size: nil,
  gqa: true
)
  # The CUDA runtime consumes between 300 MiB and 2 GiB per GPU,
  # with 500 MiB being a good estimate; each GPU loads its own copy
  cuda_size = 500 * 2 ** 20 * num_gpus

  # VRAM used by the model parameters themselves (num_params is in billions)
  params_size = num_params * 1e9 * (bpw / 8)

  # VRAM used by the KV cache for the requested context length
  kv_cache_size = (context * 2 * num_hidden_layers * hidden_size) * (kv_cache_bpw / 8)

  # Models with Grouped Query Attention share KV heads across attention heads,
  # shrinking the KV cache proportionally
  if gqa
    kv_cache_size *= num_key_value_heads.to_f / num_attention_heads.to_f
  end
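  # For example, Llama-3-8B uses 8 KV heads against 32 attention heads,
  # so GQA cuts its KV cache to a quarter of the full size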
  # VRAM used by activations. Flash Attention scales the attention terms
  # linearly with context rather than quadratically; the trailing "** 2"
  # comments mark where the quadratic factor would otherwise appear
  bytes_per_param = (bpw / 8)
  lm_head_bytes_per_param = (lm_head_bpw / 8)
  head_dim = hidden_size.to_f / num_attention_heads.to_f

  attention_input = bytes_per_param * context * hidden_size
  q = bytes_per_param * context * head_dim * num_attention_heads
  k = bytes_per_param * context * head_dim * num_key_value_heads
  v = bytes_per_param * context * head_dim * num_key_value_heads
  softmax_output = lm_head_bytes_per_param * num_attention_heads * context # ** 2
  softmax_dropout_mask = num_attention_heads * context # ** 2
  dropout_output = lm_head_bytes_per_param * num_attention_heads * context # ** 2
  out_proj_input = lm_head_bytes_per_param * context * num_attention_heads * head_dim
  attention_dropout = context * hidden_size
  attention_block = (
    attention_input +
    q +
    k +
    v +
    softmax_output +
    softmax_dropout_mask +
    dropout_output +
    out_proj_input +
    attention_dropout
  )

  mlp_input = bytes_per_param * context * hidden_size
  activation_input = bytes_per_param * context * intermediate_size
  down_proj_input = bytes_per_param * context * intermediate_size
  dropout_mask = context * hidden_size
  mlp_block = mlp_input + activation_input + down_proj_input + dropout_mask

  layer_norms = bytes_per_param * context * hidden_size * 2
  activations_size = attention_block + mlp_block + layer_norms

  # VRAM used by the output logits
  output_size = lm_head_bytes_per_param * context * vocab_size

  # Total
  vram_bytes = cuda_size + params_size + activations_size + output_size + kv_cache_size
  return bytes_to_gb(vram_bytes).round(2)
end
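# A direct call with illustrative values approximating a Llama-3-8B config
# (the keyword values below are assumptions for the example, not fetched
# from a real model):
#
#   calculate_vram_raw(
#     num_params: 8.0, bpw: 4.85, lm_head_bpw: 6.0, kv_cache_bpw: 8.0,
#     context: 8192, num_hidden_layers: 32, hidden_size: 4096,
#     num_key_value_heads: 8, num_attention_heads: 32,
#     intermediate_size: 14336, vocab_size: 128256
#   ) # => estimated VRAM in GiB, rounded to 2 decimal places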
def download_file(url, filename, headers: {})
  # Determine the base directory for caching downloads
  base_dir = File.expand_path(File.join(File.dirname(__FILE__), "cache"))

  # Join the base directory with the filename to get the full file path
  file_path = File.join(base_dir, filename)

  # Already cached
  return if File.exist?(file_path)

  # Use open-uri to open the URL and read the file in binary mode
  URI.open(url, headers) do |readable|
    # Ensure the cache directory exists; create it if it doesn't
    FileUtils.mkdir_p(File.dirname(file_path))
    # Open the new file in binary write mode
    File.open(file_path, "wb") do |writable|
      # Copy the contents of the URL to the new file in 1024-byte chunks
      while chunk = readable.read(1024)
        writable.write(chunk)
      end
    end
  end
end
def download_model_configs(model_id, access: nil)
  url = "https://huggingface.co/#{model_id}"
  readme_url = "#{url}/raw/main/README.md"
  config_url = "#{url}/raw/main/config.json"
  index_url = "#{url}/raw/main/model.safetensors.index.json"

  readme_path = File.join(model_id, "README.md")
  config_path = File.join(model_id, "config.json")
  index_path = File.join(model_id, "model.safetensors.index.json")

  # Gated models require an access token
  headers = {}
  headers["Authorization"] = "Bearer #{access}" if access

  download_file(readme_url, readme_path, headers: headers)
  download_file(config_url, config_path, headers: headers)
  download_file(index_url, index_path, headers: headers)
end
def get_model_config(model_id, access: nil)
  download_model_configs(model_id, access: access)

  # Read from the same cache directory that download_file writes to
  cache_dir = File.join(File.dirname(__FILE__), "cache", model_id)
  config = JSON.parse(File.read(File.join(cache_dir, "config.json")))
  index = JSON.parse(File.read(File.join(cache_dir, "model.safetensors.index.json")))

  # Infer the parameter count in billions from the safetensors total size,
  # assuming 16-bit weights (2 bytes per parameter)
  num_params = index["metadata"]["total_size"].to_f / 2 / 1e9
  config["num_params"] = num_params
  return config
end
# Accept either a numeric bits-per-weight value or a GGUF quant ID
def parse_bpw(bpw)
  return bpw.to_f if bpw.to_f > 0
  return GGUF_MAPPING[bpw.upcase]
end
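# For example:
#   parse_bpw("Q4_K_M") # => 4.85
#   parse_bpw("4.5")    # => 4.5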
def get_bpw_values(bpw, fp8)
  # Parse the bpw value if it was given as a GGUF quant ID
  bpw = parse_bpw(bpw)

  # Exllama2 supports specifying a different bpw for the lm_head layer;
  # use reasonable values
  if bpw > 6.0
    lm_head_bpw = 8.0
  else
    lm_head_bpw = 6.0
  end

  # Exllama2 supports an FP8 KV cache, which effectively halves the
  # context memory usage; default to true
  if fp8
    kv_cache_bpw = 8
  else
    kv_cache_bpw = 16
  end

  return {
    bpw: bpw,
    lm_head_bpw: lm_head_bpw,
    kv_cache_bpw: kv_cache_bpw,
  }
end
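# For example:
#   get_bpw_values("Q8_0", true) # => {bpw: 8.5, lm_head_bpw: 8.0, kv_cache_bpw: 8}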
# Mode A: how much VRAM does this model require?
def calculate_vram(model_id, bpw=5.0, context=nil, fp8=true, access: nil)
  # Get the config
  config = get_model_config(model_id, access: access)

  # Determine all the bpw values
  bpw_values = get_bpw_values(bpw, fp8)

  # Context length defaults to the model's maximum
  context ||= config["max_position_embeddings"]

  return calculate_vram_raw(
    num_params: config["num_params"],
    bpw: bpw_values[:bpw],
    lm_head_bpw: bpw_values[:lm_head_bpw],
    kv_cache_bpw: bpw_values[:kv_cache_bpw],
    context: context,
    num_hidden_layers: config["num_hidden_layers"],
    hidden_size: config["hidden_size"],
    num_key_value_heads: config["num_key_value_heads"],
    num_attention_heads: config["num_attention_heads"],
    intermediate_size: config["intermediate_size"],
    vocab_size: config["vocab_size"],
  )
end
# Mode B: how much context fits in a given amount of VRAM?
def calculate_context(model_id, memory=48, bpw=5.0, fp8=true, access: nil)
  # Get the config
  config = get_model_config(model_id, access: access)
  min_context = 2048
  max_context = config["max_position_embeddings"]

  # Binary search to find the coarse value
  low, high = min_context, max_context
  while low < high
    mid = (low + high + 1) / 2
    if calculate_vram(model_id, bpw, mid, fp8, access: access) > memory
      high = mid - 1
    else
      low = mid
    end
  end

  # Linear search to find the fine value
  context = low
  while calculate_vram(model_id, bpw, context, fp8, access: access) < memory && context <= max_context
    context += 100
  end

  # Found the maximum context
  return context - 100
end
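# For example, equivalent to the Mode B CLI example below:
#   calculate_context("NousResearch/Hermes-2-Theta-Llama-3-8B", 6, 8) # => 1948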
# Mode C: what is the best quantization that fits in a given amount of VRAM?
def calculate_bpw(model_id, memory=48, context=nil, fp8=true, type: :exl2, access: nil)
  case type
  when :exl2
    # Options are ordered highest-first, so the first fit is the best
    EXL2_OPTIONS.each do |bpw|
      return bpw if calculate_vram(model_id, bpw, context, fp8, access: access) < memory
    end
  when :gguf
    GGUF_MAPPING.each do |name, bpw|
      return name if calculate_vram(model_id, bpw, context, fp8, access: access) < memory
    end
  end

  # Nothing fits
  return nil
end
if $0 == __FILE__
  require "optparse"

  options = {}
  parser = OptionParser.new do |opts|
    opts.banner = "Usage: vram.rb [options] model_id"

    opts.separator <<~EOF
      Global Options:
    EOF
    opts.on("-a", "--access TOKEN", "specify your huggingface.co access token (optional)"){ |v| options[:access] = v }
    opts.on("-f", "--[no-]fp8", "use fp8 kv cache (default: true)"){ |v| options[:fp8] = v }

    opts.separator <<~EOF
      Mode-specific Options:
      - Mode A: Supply "--bpw" and "--context" to get the amount of VRAM required
      - Mode B: Supply "--ram" and "--bpw" to get the amount of context you can fit in your VRAM
      - Mode C: Supply "--ram" and "--context" to get the best BPW you can fit in your VRAM
    EOF
    opts.on("-m", "--mode MODE", "Select mode from above (default: A)"){ |v| options[:mode] = v.downcase.to_sym }
    opts.on("-b", "--bpw BPW", "Bits per Weight (default: 5)"){ |v| options[:bpw] = v }
    opts.on("-c", "--context CONTEXT", "set context (default: use model setting)"){ |v| options[:context] = v.to_i }
    opts.on("-r", "--ram RAM", "Available VRAM in GB (default: 48)"){ |v| options[:memory] = v }
    opts.on("-t", "--type TYPE", "Type of quantization [exl2, gguf] (default: exl2)"){ |v| options[:type] = v.downcase.to_sym }

    opts.separator <<~EOF
      Note:
        You can use numbers like "4.85" OR GGUF Quant IDs like "Q5_K_M" for the "--bpw" value.
        For mode C, specifying "--type gguf" will return GGUF Quant IDs, while "--type exl2" will return floating point numbers.

      Examples:
        # How much memory do I need to run a model quantized to IQ3_M?
        ./vram.rb NousResearch/Hermes-2-Theta-Llama-3-8B --mode A --bpw IQ3_M
        # 5.39

        # How much context can I get out of this model?
        ./vram.rb NousResearch/Hermes-2-Theta-Llama-3-8B --mode B --ram 6 --bpw 8
        # 1948

        # What is the best quant I can run of this model?
        ./vram.rb NousResearch/Hermes-2-Theta-Llama-3-8B --mode C --ram 6 --type gguf
        # Q3_K_L
    EOF
  end
  parser.parse!
  if ARGV.empty?
    puts "Error: Model ID is required"
    puts parser.help
    exit 1
  end
  model_id = ARGV[0]

  # Default values
  options[:fp8] = true if options[:fp8].nil?
  options[:mode] ||= :a
  options[:bpw] ||= 5.0
  # options[:context] stays nil by default, meaning the model's maximum
  options[:memory] ||= 48
  options[:memory] = options[:memory].to_f
  options[:type] ||= :exl2
  case options[:mode]
  when :a
    vram_required = calculate_vram(model_id, options[:bpw], options[:context], options[:fp8], access: options[:access])
    puts vram_required
  when :b
    largest_context = calculate_context(model_id, options[:memory], options[:bpw], options[:fp8], access: options[:access])
    puts largest_context
  when :c
    best_bpw = calculate_bpw(model_id, options[:memory], options[:context], options[:fp8], type: options[:type], access: options[:access])
    if best_bpw
      puts best_bpw
    else
      puts 0
    end
  end
end