@barronalex, created December 8, 2024
Gist: barronalex/0f816636443ca7b6821f31ef488da2a6
"""Benchmark MLX quantization settings: convert a Hugging Face model at
several bit widths and evaluate each variant on a suite of lm-eval tasks."""

import argparse
import json
import os
import tempfile
from pathlib import Path

import lm_eval
from mlx.nn import Module
from mlx_lm.evaluate import MLXLM
from mlx_lm.utils import convert


def mixed_4_6(path: str, module: Module, config: dict):
    """A mixed quantization with similar choices to llama.cpp's Q4_K_M."""
    if not hasattr(module, "to_quantized"):
        return False
    # Extract the layer index from paths like "model.layers.<i>.self_attn.v_proj".
    index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0
    num_layers = config["num_hidden_layers"]
    # Spend extra bits on the first and last eighth of the layers, and on
    # every third layer in between.
    use_more_bits = (
        index < num_layers // 8
        or index >= 7 * num_layers // 8
        or (index - num_layers // 8) % 3 == 2
    )
    if "v_proj" in path and use_more_bits:
        return {"group_size": 64, "bits": 6}
    if "down_proj" in path and use_more_bits:
        return {"group_size": 64, "bits": 6}
    if "lm_head" in path:
        return {"group_size": 64, "bits": 6}
    return {"group_size": 64, "bits": 4}


def mixed_3_4(path: str, module: Module, config: dict):
    """A mixed 3/4/6 bit quantization."""
    if not hasattr(module, "to_quantized"):
        return False
    index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0
    num_layers = config["num_hidden_layers"]
    use_more_bits = (
        index < num_layers // 8
        or index >= 7 * num_layers // 8
        or (index - num_layers // 8) % 3 == 2
    )
    if "v_proj" in path:
        return {"group_size": 64, "bits": 4}
    if "down_proj" in path and use_more_bits:
        return {"group_size": 64, "bits": 4}
    if "lm_head" in path:
        return {"group_size": 64, "bits": 6}
    return {"group_size": 64, "bits": 3}


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a Hugging Face model to MLX at various quantization settings and evaluate each variant."
    )
    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.", required=True)
    parser.add_argument("--output-dir", type=str, help="Path to store output.", required=True)
    parser.add_argument("--seed", type=int, help="Random seed", default=123)
    args = parser.parse_args()

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    tasks = [
        "arc_challenge",
        "arc_easy",
        "winogrande",
        "boolq",
        "piqa",
        "social_iqa",
        "hellaswag",
        "openbookqa",
    ]

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    benchmarks = [
        (16, None, None),
        (4, 64, mixed_4_6),
        (3, 64, mixed_3_4),
        (8, 64, None),
        (6, 64, None),
        (4, 32, None),
        (4, 64, None),
        (4, 128, None),
        (3, 32, None),
        (3, 64, None),
        (3, 128, None),
        (2, 32, None),
    ]
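
    # Each (bits, group_size, quant_predicate) triple above is converted and
    # evaluated in turn; the 16-bit entry skips conversion and serves as the
    # unquantized baseline.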
    for bits, group_size, quant_predicate in benchmarks:
        with tempfile.TemporaryDirectory() as tmp_dir:
            if bits == 16:
                # No conversion needed: evaluate the original model directly.
                model_dir = args.hf_path
            else:
                model_dir = Path(tmp_dir) / "mlx_model"
                convert(
                    args.hf_path,
                    model_dir,
                    quantize=True,
                    q_bits=bits,
                    q_group_size=group_size,
                    quant_predicate=quant_predicate,
                )
            lm = MLXLM(model_dir)
            results = lm_eval.simple_evaluate(
                model=lm,
                tasks=tasks,
                num_fewshot=0,
                random_seed=args.seed,
                numpy_random_seed=args.seed,
                torch_random_seed=args.seed,
                fewshot_random_seed=args.seed,
            )
            qname = quant_predicate.__name__ if quant_predicate else "None"
            filename = f"eval_{bits}_{group_size}_{qname}.json"
            output_path = output_dir / filename
            output_path.write_text(json.dumps(results["results"], indent=4))
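
A sketch of how this might be run, assuming mlx-lm and lm-eval are installed and the script is saved as benchmark_quants.py (a hypothetical filename):

    python benchmark_quants.py --hf-path <hf-model-id> --output-dir results

Each configuration writes its scores to eval_<bits>_<group_size>_<predicate>.json in the output directory.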