Created
September 23, 2024 14:19
-
-
Save robbiemu/eb689b8d93d2134e2f098e929f5befba to your computer and use it in GitHub Desktop.
VRAM estimator for Macs using Ollama + gollama
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
vram llama3.1:8b-instruct-q8_0 --verbose | |
VERBOSE: Default fits value from sysctl: 40.0 GB | |
VERBOSE: Quant value for llama3.1:8b-instruct-q8_0: Q8_0 | |
VERBOSE: VRAM nth for llama3.1:8b-instruct-q8_0: 131072 | |
VERBOSE: Running gollama -vram for llama3.1:8b-instruct-q8_0 with fits=40.0 GB | |
VERBOSE: VRAM output header, labels, and rows gathered | |
VERBOSE: Quant row: | Q8_0 | 8.50 | 9.1 | 10.9 | 13.4(12.4,11.9) | 18.4(16.4,15.4) | 28.3(24.3,22.3) | 48.2(40.2,36.2) | | |
VERBOSE: Max A: 28.3 at 64K | |
VERBOSE: Max B: 24.3 at 64K | |
VERBOSE: Max C: 36.2 at 128K | |
VERBOSE: Final Output: 64K @28.3 ( 64K @24.3, 128K @36.2) | |
Using fits value: 40.00 GB | |
64K @28.3 ( 64K @24.3, 128K @36.2) | |
vram llama3.1:8b-instruct-q8_0 | |
Using fits value: 40.00 GB | |
64K @28.3 ( 64K @24.3, 128K @36.2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import subprocess | |
import re | |
import sys | |
import logging | |
def run_command(command):
    """Execute *command* (an argv list) and return its stripped stdout.

    Logs the failure and exits the process with status 1 if the command
    returns a non-zero exit code.
    """
    try:
        completed = subprocess.run(
            command, stdout=subprocess.PIPE, text=True, check=True
        )
    except subprocess.CalledProcessError as err:
        logging.error(f"Error running command {' '.join(command)}: {err}")
        sys.exit(1)
    return completed.stdout.strip()
def extract_quant(model_name):
    """Look up the quantization label for *model_name* in `gollama -l` output.

    The listing is expected to show the model name followed by two
    whitespace-separated fields; the second one is the quant label.
    Exits with status 1 when the model is not found.
    """
    listing = run_command(["gollama", "-l"])
    hit = re.search(rf"{re.escape(model_name)}\s+\S+\s+(\S+)", listing)
    if hit is None:
        logging.error(f"Could not find quant value for model '{model_name}'")
        sys.exit(1)
    quant = hit.group(1)
    logging.info(f"Quant value for {model_name}: {quant}")
    return quant
def extract_vram_nth(model_name):
    """Read the model's context length from `ollama show <model>` output.

    Returns the context length as a string (it is passed straight through
    to gollama on the command line).  Exits with status 1 if the field
    cannot be found.
    """
    show_output = run_command(["ollama", "show", model_name])
    hit = re.search(r"context length\s+(\d+)", show_output)
    if hit is None:
        logging.error(f"Could not extract context length for model '{model_name}'")
        sys.exit(1)
    vram_nth = hit.group(1)
    logging.info(f"VRAM nth for {model_name}: {vram_nth}")
    return vram_nth
def run_vram_estimation(model_name, vram_nth, fits_limit):
    """Invoke `gollama -vram` for *model_name* and return its raw table output.

    *vram_nth* is the model's context length and *fits_limit* the VRAM
    budget in GB; both are forwarded to gollama unchanged.
    """
    logging.info(f"Running gollama -vram for {model_name} with fits={fits_limit} GB")
    cmd = [
        "gollama", "-vram", model_name,
        "--fits", str(fits_limit),
        "--vram-to-nth", vram_nth,
    ]
    return run_command(cmd)
def find_largest_below_fits(vram_output, quant, fits):
    """Find the largest A, B, and C VRAM values at or below the fits limit.

    ``vram_output`` is the table printed by ``gollama -vram``: a header
    line, a separator line, a line of column labels, then one row per
    quant level.  Each data cell is either a single value ``A`` or
    ``A(B,C)`` — VRAM estimates in GB.  NOTE(review): the label for data
    column *i* is read from labels index ``i + 1``; this offset matches
    gollama's table layout as observed, confirm if gollama changes.

    Returns ``(header, labels, separator, final_output)`` where
    ``final_output`` is the formatted ``"ctx@A (ctx@B, ctx@C)"`` summary.
    Exits with status 1 when the output is malformed, the quant row is
    missing, or nothing fits.
    """
    lines = vram_output.splitlines()
    # Guard: the original indexed lines[0..2] unconditionally and could
    # raise IndexError on truncated output.
    if len(lines) < 4:
        logging.error("Unexpected gollama -vram output: fewer than 4 lines")
        sys.exit(1)
    header, separator, labels = lines[0], lines[1], lines[2]
    rows = lines[3:]
    logging.info("VRAM output header, labels, and rows gathered")

    # Locate the row for our quant level (substring match on the row text).
    quant_row = next((row for row in rows if quant in row), None)
    if not quant_row:
        logging.error(f"Could not find matching row for quant '{quant}'")
        sys.exit(1)
    logging.info(f"Quant row: {quant_row}")

    # Column labels are offset by one relative to the data cells.
    column_names = [col.strip() for col in labels.split('|')[3:]]
    columns = quant_row.split("|")[3:]

    # "A" or "A(B,C)" — B and C are optional.
    cell_re = re.compile(r'([\d\.]+)(?:\(([\d\.]+),\s*([\d\.]+)\))?')

    max_A, max_A_ctx = None, None
    max_B, max_B_ctx = None, None
    max_C, max_C_ctx = None, None
    for idx, col in enumerate(columns):
        match = cell_re.match(col.strip())
        if not match:
            continue  # blank or non-numeric cell
        A_val = float(match.group(1))
        # BUGFIX: missing (B,C) groups were previously coerced to 0.0,
        # which always passes the `<= fits` test and could be reported as
        # a bogus "0.0 GB" maximum; treat them as absent instead.
        B_val = float(match.group(2)) if match.group(2) else None
        C_val = float(match.group(3)) if match.group(3) else None
        ctx_size = column_names[idx + 1] if idx + 1 < len(column_names) else "Unknown"
        # `>=` keeps the tie-breaking of the original: on equal values the
        # later (larger-context) column wins.
        if A_val <= fits and (max_A is None or A_val >= max_A):
            max_A, max_A_ctx = A_val, ctx_size
        if B_val is not None and B_val <= fits and (max_B is None or B_val >= max_B):
            max_B, max_B_ctx = B_val, ctx_size
        if C_val is not None and C_val <= fits and (max_C is None or C_val >= max_C):
            max_C, max_C_ctx = C_val, ctx_size

    logging.info(f"Max A: {max_A} at {max_A_ctx}")
    logging.info(f"Max B: {max_B} at {max_B_ctx}")
    logging.info(f"Max C: {max_C} at {max_C_ctx}")

    if max_A is None and max_B is None and max_C is None:
        logging.error(f"No values found below the fits limit of {fits} GB")
        sys.exit(1)
    final_output = f"{max_A_ctx}@{max_A} ({max_B_ctx}@{max_B}, {max_C_ctx}@{max_C})"
    logging.info(f"Final Output: {final_output}")
    return header, labels, separator, final_output
def get_default_fits():
    """Derive the default fits limit (GB) from sysctl iogpu.wired_limit_mb.

    Reads the macOS GPU wired-memory limit (reported in MB) and converts
    it to GB.  Exits with status 1 when no numeric value can be parsed.
    """
    sysctl_out = run_command(["sysctl", "iogpu.wired_limit_mb"])
    digits = re.search(r"(\d+)", sysctl_out)
    if digits is None:
        logging.error("Could not retrieve iogpu.wired_limit_mb from sysctl")
        sys.exit(1)
    fits = int(digits.group(1)) / 1024  # MB -> GB
    logging.info(f"Default fits value from sysctl: {fits} GB")
    return fits
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Estimate VRAM usage for a given model."
    )
    parser.add_argument("model_name", help="Name of the model")
    parser.add_argument(
        "--fits", type=float, default=None,
        help="Fits limit in GB (default: iogpu.wired_limit_mb / 1024)"
    )
    parser.add_argument("--verbose", "-v",
                        action="store_true", help="Verbose output")
    args = parser.parse_args()

    # Set up logging with a 'VERBOSE' prefix instead of the default 'DEBUG'.
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.ERROR,
                        format='VERBOSE: %(message)s' if args.verbose else '%(message)s',
                        stream=sys.stderr)

    model_name = args.model_name
    # BUGFIX: `args.fits or get_default_fits()` silently discarded an
    # explicit `--fits 0` because 0.0 is falsy; compare to None instead.
    fits_limit = args.fits if args.fits is not None else get_default_fits()

    quant = extract_quant(model_name)
    vram_nth = extract_vram_nth(model_name)
    vram_output = run_vram_estimation(model_name, vram_nth, fits_limit)
    header, labels, separator, largest_col = \
        find_largest_below_fits(vram_output, quant, fits_limit)

    print(f"Using fits value: {fits_limit:.2f} GB")
    print(largest_col)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment