@awni
Created October 17, 2024 15:57
Faster CPU HF to MLX conversion script
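A minimal usage sketch, assuming mlx and mlx-lm are installed and the script below is saved locally (the file name convert_fast.py is just an example, and any Hugging Face repo id can stand in for the one shown):

    python convert_fast.py --hf-path mistralai/Mistral-7B-v0.1 --mlx-path mlx_model -q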
import argparse
from functools import partial
import multiprocessing as mp
from typing import Callable, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map_with_path
# The wildcard import provides Path, glob, shutil, tree_flatten, and the
# conversion helpers (get_model_path, fetch_from_hub, quantize_model,
# dequantize_model, save_weights, save_config, upload_to_hub) used below.
from mlx_lm.utils import *
# Swap in a compiled mx.quantize; the quantized nn layers call it through the
# mx module attribute, so they pick up the compiled version.
mx.quantize = mx.compile(mx.quantize)
MAX_NUM_THREADS = 1
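# A small round-robin pool of CPU streams. Each module's dtype cast and
# quantization below is issued on the next stream from this pool, so
# independent modules can be processed concurrently; MAX_NUM_THREADS caps
# how many streams are created.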
class StreamPool:
    def __init__(self):
        num_cpus = mp.cpu_count()
        self._pool = [
            mx.new_stream(mx.cpu) for _ in range(min(num_cpus, MAX_NUM_THREADS))
        ]
        self._idx = -1

    def next(self):
        self._idx += 1
        self._idx %= len(self._pool)
        return self._pool[self._idx]
stream_pool = StreamPool()
def multi_quantize(
    model: nn.Module,
    group_size: int = 64,
    bits: int = 4,
    class_predicate: Optional[Callable] = None,
    dtype: mx.Dtype = mx.float16,
):
    class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))

    def _maybe_quantize(path, m):
        # Issue this module's cast and quantization on the next stream in the pool
        with mx.stream(stream_pool.next()):
            m.set_dtype(dtype)
            if class_predicate(path, m):
                if hasattr(m, "to_quantized"):
                    return m.to_quantized(group_size, bits)
                else:
                    raise ValueError(f"Unable to quantize model of type {type(m)}")
            else:
                return m

    leaves = model.leaf_modules()
    leaves = tree_map_with_path(_maybe_quantize, leaves, is_leaf=nn.Module.is_module)
    model.update_modules(leaves)
def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: Optional[str] = None,
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
):
    # Check that the save path is empty
    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)
    if mlx_path.exists():
        raise ValueError(
            f"Cannot save to the path {mlx_path} as it already exists."
            " Please delete the file/directory or specify a new path to save to."
        )

    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = tree_flatten(model.parameters())
    dtype = getattr(mx, dtype) if dtype else weights[0][1].dtype
    weights = dict(weights)

    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        # Patch nn.quantize so quantize_model dispatches work across the stream pool
        nn.quantize = partial(multi_quantize, dtype=dtype)
        weights, config = quantize_model(model, config, q_group_size, q_bits)

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)
        weights = dict(tree_flatten(model.parameters()))

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    # Copy any custom modeling code shipped with the Hugging Face checkpoint
    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)
    save_config(config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)
def make_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )
    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
    parser.add_argument(
        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
    )
    parser.add_argument(
        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
    )
    parser.add_argument(
        "--q-group-size", help="Group size for quantization.", type=int, default=64
    )
    parser.add_argument(
        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the parameters.",
        type=str,
        choices=["float16", "bfloat16", "float32"],
        default="float16",
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-d",
        "--dequantize",
        help="Dequantize a quantized model.",
        action="store_true",
        default=False,
    )
    return parser


if __name__ == "__main__":
    parser = make_parser()
    args = parser.parse_args()
    convert(**vars(args))