Faster CPU HF to MLX conversion script
import argparse
from functools import partial
import multiprocessing as mp
from typing import Callable, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map_with_path
# The star-import also brings in Path, glob, shutil, tree_flatten, and the
# mlx_lm model loading/saving helpers used below.
from mlx_lm.utils import *

# Swap in a compiled version of mx.quantize to speed up per-layer quantization.
mx.quantize = mx.compile(mx.quantize)

# Number of CPU streams used by StreamPool; raise to spread work across more cores.
MAX_NUM_THREADS = 1
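
# Round-robin pool of CPU streams: each call to next() returns the next stream
# in the pool, so work launched on different streams can overlap.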
class StreamPool:
    def __init__(self):
        num_cpus = mp.cpu_count()
        self._pool = [
            mx.new_stream(mx.cpu) for _ in range(min(num_cpus, MAX_NUM_THREADS))
        ]
        self._idx = -1

    def next(self):
        self._idx += 1
        self._idx %= len(self._pool)
        return self._pool[self._idx]


stream_pool = StreamPool()
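
# Drop-in replacement for nn.quantize: each leaf module is quantized under a
# stream taken from the pool, so with MAX_NUM_THREADS > 1 the per-layer work
# can be spread across several CPU streams.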
def multi_quantize(
    model: nn.Module,
    group_size: int = 64,
    bits: int = 4,
    class_predicate: Optional[Callable] = None,
    dtype: mx.Dtype = mx.float16,
):
    class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))

    def _maybe_quantize(path, m):
        with mx.stream(stream_pool.next()):
            m.set_dtype(dtype)
            if class_predicate(path, m):
                if hasattr(m, "to_quantized"):
                    return m.to_quantized(group_size, bits)
                else:
                    raise ValueError(f"Unable to quantize model of type {type(m)}")
            else:
                return m

    leaves = model.leaf_modules()
    leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=nn.Module.is_module)
    model.update_modules(leaves)
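
# Essentially the same flow as mlx_lm's convert(), except that nn.quantize is
# patched with multi_quantize before quantize_model is called.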
def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: Optional[str] = None,
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
):
    # Check that the save path is empty
    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)
    if mlx_path.exists():
        raise ValueError(
            f"Cannot save to the path {mlx_path} as it already exists."
            " Please delete the file/directory or specify a new path to save to."
        )

    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = tree_flatten(model.parameters())
    dtype = getattr(mx, dtype) if dtype else weights[0][1].dtype
    weights = dict(weights)

    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        # Patch nn.quantize so quantize_model uses the multi-stream version above.
        nn.quantize = partial(multi_quantize, dtype=dtype)
        weights, config = quantize_model(model, config, q_group_size, q_bits)

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)
        weights = dict(tree_flatten(model.parameters()))

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)
    save_config(config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)
def make_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )
    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
    parser.add_argument(
        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
    )
    parser.add_argument(
        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
    )
    parser.add_argument(
        "--q-group-size", help="Group size for quantization.", type=int, default=64
    )
    parser.add_argument(
        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the parameters.",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-d",
        "--dequantize",
        help="Dequantize a quantized model.",
        action="store_true",
        default=False,
    )
    return parser
if __name__ == "__main__":
    parser = make_parser()
    args = parser.parse_args()
    convert(**vars(args))
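
For example, assuming the file is saved as convert.py (the Hugging Face repo name below is only illustrative), a 4-bit quantized conversion could be run with:

    python convert.py --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q --q-bits 4 --mlx-path mlx_model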