@robbiemu
Created September 23, 2024 14:19
vram estimator for Macs using Ollama + gollama
vram llama3.1:8b-instruct-q8_0 --verbose
VERBOSE: Default fits value from sysctl: 40.0 GB
VERBOSE: Quant value for llama3.1:8b-instruct-q8_0: Q8_0
VERBOSE: VRAM nth for llama3.1:8b-instruct-q8_0: 131072
VERBOSE: Running gollama -vram for llama3.1:8b-instruct-q8_0 with fits=40.0 GB
VERBOSE: VRAM output header, labels, and rows gathered
VERBOSE: Quant row: | Q8_0 | 8.50 | 9.1 | 10.9 | 13.4(12.4,11.9) | 18.4(16.4,15.4) | 28.3(24.3,22.3) | 48.2(40.2,36.2) |
VERBOSE: Max A: 28.3 at 64K
VERBOSE: Max B: 24.3 at 64K
VERBOSE: Max C: 36.2 at 128K
VERBOSE: Final Output: 64K @28.3 ( 64K @24.3, 128K @36.2)
Using fits value: 40.00 GB
64K @28.3 ( 64K @24.3, 128K @36.2)

vram llama3.1:8b-instruct-q8_0
Using fits value: 40.00 GB
64K @28.3 ( 64K @24.3, 128K @36.2)
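
How the summary line is built: gollama's -vram table has one row per quantization, and each cell packs up to three VRAM estimates, e.g. 28.3(24.3,22.3) in the Q8_0 row above. For each of the three positions the script keeps the largest value that still fits under the limit, together with its context-size column, which is what produces the "64K @28.3 ( 64K @24.3, 128K @36.2)" line. As a minimal standalone sketch, here is that cell parsing on its own, reusing the regular expression from the script below; parse_cell and CELL are illustration-only names, not part of the script:

import re

# Matches a gollama VRAM cell such as "28.3(24.3,22.3)" or a bare "9.1";
# the parenthesized pair of alternate estimates is optional.
CELL = re.compile(r'([\d\.]+)(?:\(([\d\.]+),\s*([\d\.]+)\))?')

def parse_cell(cell):
    """Return the (A, B, C) estimates for one table cell, or None if it doesn't parse."""
    match = CELL.match(cell.strip())
    if not match:
        return None
    a = float(match.group(1))
    b = float(match.group(2) or 0)  # missing alternates default to 0, as in the script
    c = float(match.group(3) or 0)
    return a, b, c

print(parse_cell("28.3(24.3,22.3)"))  # -> (28.3, 24.3, 22.3)
print(parse_cell("9.1"))              # -> (9.1, 0.0, 0.0)
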
#!/usr/bin/env python3
import argparse
import subprocess
import re
import sys
import logging


def run_command(command):
    """Run a shell command and return its output."""
    try:
        result = subprocess.run(command, stdout=subprocess.PIPE,
                                text=True, check=True)
        return result.stdout.strip()
    except subprocess.CalledProcessError as e:
        logging.error(f"Error running command {' '.join(command)}: {e}")
        sys.exit(1)


def extract_quant(model_name):
    """Extract the quant value from gollama -l output for the given model."""
    output = run_command(["gollama", "-l"])
    pattern = re.compile(rf"{re.escape(model_name)}\s+\S+\s+(\S+)")
    match = pattern.search(output)
    if match:
        quant = match.group(1)
        logging.info(f"Quant value for {model_name}: {quant}")
        return quant
    else:
        logging.error(f"Could not find quant value for model '{model_name}'")
        sys.exit(1)


def extract_vram_nth(model_name):
    """Extract the context length (vram_nth) from ollama show <model> output."""
    output = run_command(["ollama", "show", model_name])
    pattern = re.compile(r"context length\s+(\d+)")
    match = pattern.search(output)
    if match:
        vram_nth = match.group(1)
        logging.info(f"VRAM nth for {model_name}: {vram_nth}")
        return vram_nth
    else:
        logging.error(f"Could not extract context length for model '{model_name}'")
        sys.exit(1)


def run_vram_estimation(model_name, vram_nth, fits_limit):
    """Run gollama -vram estimation and return the output."""
    logging.info(f"Running gollama -vram for {model_name} with fits={fits_limit} GB")
    output = run_command([
        "gollama", "-vram", model_name, "--fits", str(fits_limit),
        "--vram-to-nth", vram_nth
    ])
    return output


def find_largest_below_fits(vram_output, quant, fits):
    """Find the largest A, B, and C values below the fits limit, along with their column names."""
    lines = vram_output.splitlines()
    header = lines[0]
    separator = lines[1]
    labels = lines[2]
    rows = lines[3:]
    logging.info("VRAM output header, labels, and rows gathered")

    # Find the quant row
    quant_row = None
    for row in rows:
        if quant in row:
            quant_row = row
            break
    if not quant_row:
        logging.error(f"Could not find matching row for quant '{quant}'")
        sys.exit(1)
    logging.info(f"Quant row: {quant_row}")

    # Extract column names from labels
    column_names = [col.strip() for col in labels.split('|')[3:]]
    columns = quant_row.split("|")[3:]

    max_A, max_A_ctx = None, None
    max_B, max_B_ctx = None, None
    max_C, max_C_ctx = None, None
    for idx, col in enumerate(columns):
        col = col.strip()
        match = re.match(r'([\d\.]+)(?:\(([\d\.]+),\s*([\d\.]+)\))?', col)
        if match:
            A_val = float(match.group(1))
            B_val = float(match.group(2) or 0)
            C_val = float(match.group(3) or 0)
            ctx_size = (column_names[idx + 1]
                        if idx + 1 < len(column_names) else "Unknown")
            if A_val <= fits and (max_A is None or A_val >= max_A):
                max_A = A_val
                max_A_ctx = ctx_size
            if B_val <= fits and (max_B is None or B_val >= max_B):
                max_B = B_val
                max_B_ctx = ctx_size
            if C_val <= fits and (max_C is None or C_val >= max_C):
                max_C = C_val
                max_C_ctx = ctx_size

    logging.info(f"Max A: {max_A} at {max_A_ctx}")
    logging.info(f"Max B: {max_B} at {max_B_ctx}")
    logging.info(f"Max C: {max_C} at {max_C_ctx}")

    if max_A is not None or max_B is not None or max_C is not None:
        final_output = f"{max_A_ctx}@{max_A} ({max_B_ctx}@{max_B}, {max_C_ctx}@{max_C})"
        logging.info(f"Final Output: {final_output}")
        return header, labels, separator, final_output
    else:
        logging.error(f"No values found below the fits limit of {fits} GB")
        sys.exit(1)


def get_default_fits():
    """Get the default fits value from sysctl iogpu.wired_limit_mb."""
    output = run_command(["sysctl", "iogpu.wired_limit_mb"])
    match = re.search(r"(\d+)", output)
    if match:
        wired_limit_mb = int(match.group(1))
        fits = wired_limit_mb / 1024  # Convert MB to GB
        logging.info(f"Default fits value from sysctl: {fits} GB")
        return fits
    else:
        logging.error("Could not retrieve iogpu.wired_limit_mb from sysctl")
        sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Estimate VRAM usage for a given model.")
    parser.add_argument("model_name", help="Name of the model")
    parser.add_argument(
        "--fits", type=float, default=None,
        help="Fits limit in GB (default: iogpu.wired_limit_mb / 1024)"
    )
    parser.add_argument("--verbose", "-v",
                        action="store_true", help="Verbose output")
    args = parser.parse_args()

    # Set up logging with 'VERBOSE' instead of 'DEBUG'
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.ERROR,
                        format='VERBOSE: %(message)s' if args.verbose else '%(message)s',
                        stream=sys.stderr)

    model_name = args.model_name
    fits_limit = args.fits or get_default_fits()

    quant = extract_quant(model_name)
    vram_nth = extract_vram_nth(model_name)
    vram_output = run_vram_estimation(model_name, vram_nth, fits_limit)
    header, labels, separator, largest_col = \
        find_largest_below_fits(vram_output, quant, fits_limit)

    print(f"Using fits value: {fits_limit:.2f} GB")
    print(largest_col)