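"""Benchmark MLX quantization settings on a Hugging Face model with lm-eval.

Converts the model at several bit widths and group sizes (including two
mixed-precision recipes) and writes per-task lm-eval results as JSON.
"""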
import argparse
import json
import os
import tempfile
from pathlib import Path
from typing import Union

import lm_eval
from mlx.nn import Module
from mlx_lm.evaluate import MLXLM
from mlx_lm.utils import convert


def mixed_4_6(path: str, module: Module, config: dict) -> Union[bool, dict]:
    """A mixed 4/6-bit quantization with similar choices to llama.cpp's Q4_K_M."""
    if not hasattr(module, "to_quantized"):
        return False
    # Parse the layer index from paths like "model.layers.12.self_attn.v_proj".
    index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0
    num_layers = config["num_hidden_layers"]
    # Spend extra bits on the first and last eighth of the layers, plus every
    # third layer in between, mirroring llama.cpp's use_more_bits heuristic.
    use_more_bits = (
        index < num_layers // 8
        or index >= 7 * num_layers // 8
        or (index - num_layers // 8) % 3 == 2
    )
    if "v_proj" in path and use_more_bits:
        return {"group_size": 64, "bits": 6}
    if "down_proj" in path and use_more_bits:
        return {"group_size": 64, "bits": 6}
    if "lm_head" in path:
        return {"group_size": 64, "bits": 6}
    return {"group_size": 64, "bits": 4}

def mixed_3_4(path: str, module: Module, config: dict) -> Union[bool, dict]:
    """A mixed 3/4/6-bit quantization."""
    if not hasattr(module, "to_quantized"):
        return False
    index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0
    num_layers = config["num_hidden_layers"]
    use_more_bits = (
        index < num_layers // 8
        or index >= 7 * num_layers // 8
        or (index - num_layers // 8) % 3 == 2
    )
    # Value projections always get 4 bits; sensitive down projections get 4
    # bits on the "more bits" layers; the output head gets 6 bits.
    if "v_proj" in path:
        return {"group_size": 64, "bits": 4}
    if "down_proj" in path and use_more_bits:
        return {"group_size": 64, "bits": 4}
    if "lm_head" in path:
        return {"group_size": 64, "bits": 6}
    return {"group_size": 64, "bits": 3}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark MLX quantization settings on a Hugging Face model."
    )
    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.", required=True)
    parser.add_argument("--output-dir", type=str, help="Path to store output.", required=True)
    parser.add_argument("--seed", type=int, help="Random seed", default=123)
    args = parser.parse_args()

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    tasks = [
        "arc_challenge",
        "arc_easy",
        "winogrande",
        "boolq",
        "piqa",
        "social_iqa",
        "hellaswag",
        "openbookqa",
    ]

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # (bits, group_size, quant_predicate); bits == 16 means no quantization.
    benchmarks = [
        (16, None, None),
        (4, 64, mixed_4_6),
        (3, 64, mixed_3_4),
        (8, 64, None),
        (6, 64, None),
        (4, 32, None),
        (4, 64, None),
        (4, 128, None),
        (3, 32, None),
        (3, 64, None),
        (3, 128, None),
        (2, 32, None),
    ]

    for bits, group_size, quant_predicate in benchmarks:
        with tempfile.TemporaryDirectory() as tmp_dir:
            if bits == 16:
                # Evaluate the unquantized model directly from the HF path.
                model_dir = args.hf_path
            else:
                model_dir = Path(tmp_dir) / "mlx_model"
                convert(
                    args.hf_path,
                    model_dir,
                    quantize=True,
                    q_bits=bits,
                    q_group_size=group_size,
                    quant_predicate=quant_predicate,
                )
            lm = MLXLM(model_dir)
            results = lm_eval.simple_evaluate(
                model=lm,
                tasks=tasks,
                num_fewshot=0,
                random_seed=args.seed,
                numpy_random_seed=args.seed,
                torch_random_seed=args.seed,
                fewshot_random_seed=args.seed,
            )
            qname = quant_predicate.__name__ if quant_predicate else "None"
            output_path = output_dir / f"eval_{bits}_{group_size}_{qname}.json"
            output_path.write_text(json.dumps(results["results"], indent=4))
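A hypothetical invocation, assuming the gist is saved as `benchmark_quants.py` (the filename and model repo are placeholders; any MLX-convertible Hugging Face model should work):

python benchmark_quants.py \
    --hf-path mistralai/Mistral-7B-v0.1 \
    --output-dir results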