A shameless ripoff of https://github.com/ml-explore/mlx-lm/pull/148 with grad accum.
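Example invocation (the filename is whatever you save this file as; the flags are defined in `main()` below): `python dwq.py --model Qwen/Qwen3-1.7B --batch-size 8 --grad-accum-steps 4 --mlx-path mlx_model`.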
import argparse
import copy
import glob
import shutil
import time
import types
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optimizers
import numpy as np
from mlx.utils import tree_flatten, tree_map
from mlx_lm.tokenizer_utils import TokenizerWrapper
from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.tuner.trainer import iterate_batches
from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.utils import (
    create_model_card,
    fetch_from_hub,
    get_model_path,
    quantize_model,
    save_config,
    save_weights,
)


def dwq_quantize(
    model,
    q_model,
    opt,
    data,
    batch_size: int = 2,
    max_seq_length: int = 2048,
    temperature: float = 0.5,
    dtype: mx.Dtype = mx.bfloat16,
    grad_accum_steps: int = 1,
):
"""Distillationβaware weight quantization with gradient accumulation. | |
After every *optimizer* step (i.e. after `grad_accum_steps` microβbatches), | |
the function now logs the **mean KLβdiv loss across those microβbatches** | |
instead of spamming once per microβbatch. | |
""" | |
    assert grad_accum_steps > 0, "grad_accum_steps must be >= 1"

    group = mx.distributed.init()
    world_size = group.size()
    rank = group.rank()

    # ──────────────────── Trainable parameters ────────────────────
    def unfreeze(_, m):
        if hasattr(m, "bits") and hasattr(m, "group_size"):
            m.unfreeze(keys=["scales", "biases"], recurse=False)

    q_model.apply_to_modules(unfreeze)
    print_trainable_parameters(q_model)

    # ──────────────────── Helper functions ────────────────────
    def log_norm(x):
        x = x * (1 / temperature)
        return x - mx.logsumexp(x, axis=-1, keepdims=True)

    def loss_fn(params, x, targets, lengths):
        q_model.update(tree_map(lambda t: t.astype(dtype), params))
        logits = q_model(x).astype(mx.float32)
        losses = nn.losses.kl_div_loss(log_norm(logits), targets, reduction="none")
        mask = mx.arange(targets.shape[1]) < lengths[:, 1:]
        ntoks = mask.sum()
        loss = (mask * losses).sum() / ntoks
        return loss, ntoks

    def calc_grads(inputs, targets, lengths, params):
        (loss, ntoks), grads = mx.value_and_grad(loss_fn)(
            params, inputs, targets, lengths
        )
        grads = nn.average_gradients(grads)
        return loss, ntoks, grads

    # Keep params in fp32 while learning.
    params = tree_map(lambda x: x.astype(mx.float32), q_model.trainable_parameters())

    # Accumulation buffers.
    grad_accum = tree_map(lambda p: mx.zeros_like(p), params)
    accum_counter = 0

    # Logging helpers.
    accum_loss_sum = 0.0
    accum_loss_count = 0
    global_step = 0  # Counts *optimizer* steps.
    tokens = 0
    tic = time.time()

    for it, (batch, lengths) in enumerate(
        iterate_batches(data, batch_size, max_seq_length)
    ):
        # ──────────── Teacher forward pass (no grad) ────────────
        targets = log_norm(model(batch).astype(mx.float32))
        mx.eval(targets)
        mx.clear_cache()

        # ──────────── Student forward/backward ────────────
        loss, ntoks, grads = calc_grads(batch, targets, lengths, params)
        mx.eval(loss, grads)
        mx.clear_cache()

        # Distributed reduction for consistent loss/ntoks across devices.
        loss_red = mx.distributed.all_sum(loss, stream=mx.cpu).item() / world_size
        ntoks_red = mx.distributed.all_sum(ntoks, stream=mx.cpu).item()
        tokens += ntoks_red
        accum_loss_sum += loss_red
        accum_loss_count += 1

        # ──────────── Gradient accumulation ────────────
        grad_accum = tree_map(lambda a, b: a + b, grad_accum, grads)
        accum_counter += 1

        step_now = accum_counter == grad_accum_steps
        # iterate_batches yields about len(data) // batch_size batches, so compare
        # against that count (not len(data)) to catch the final tail batch.
        last_batch = it == len(data) // batch_size - 1
        if step_now or last_batch:
            scale = accum_counter  # May be < grad_accum_steps on tail batch.
            avg_grads = tree_map(lambda g: g / scale, grad_accum)
            params = opt.apply_gradients(avg_grads, params)

            # ──────────── Logging (once per *optimizer* step) ────────────
            avg_step_loss = accum_loss_sum / accum_loss_count
            global_step += 1
            toks_per_sec = tokens / (time.time() - tic)
            if rank == 0:
                print(
                    f"step={global_step}, avg_loss={avg_step_loss:.3f}, "
                    f"tokens={tokens}, toks_per_sec={toks_per_sec:.3f}",
                    flush=True,
                )

            # Reset accumulators for next step.
            grad_accum = tree_map(lambda p: mx.zeros_like(p), grad_accum)
            accum_counter = 0
            accum_loss_sum = 0.0
            accum_loss_count = 0

    # Push learned params back into student model (cast to final dtype).
    q_model.update(tree_map(lambda x: x.astype(dtype), params))


def save_model(
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    config,
    model_path: Path,
    mlx_path: str,
    hf_path: str,
):
    weights = dict(tree_flatten(model.parameters()))
    mlx_path = Path(mlx_path)
    save_weights(mlx_path, weights, donate_weights=True)
    for file in glob.glob(str(model_path / "*.py")):
        shutil.copy(file, mlx_path)
    tokenizer.save_pretrained(mlx_path)
    save_config(config, config_path=mlx_path / "config.json")
    create_model_card(mlx_path, hf_path)


def load_data(tokenizer, data_path: str, num_samples: int):
    args = types.SimpleNamespace(
        hf_dataset={
            "path": data_path,
            "train_split": f"train[:{num_samples}]",
            "valid_split": "train[:1]",
        },
        train=True,
        test=False,
    )
    dataset = load_dataset(args, tokenizer)[0]
    return [dataset.process(d) for d in dataset]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", default="Qwen/Qwen3-1.7B")
    parser.add_argument("--quantized-model", default=None)
    parser.add_argument("--mlx-path", default="mlx_model", help="Path to save the quantized model.")
    parser.add_argument("--bits", type=int, default=4, help="Bits per weight for quantization.")
    parser.add_argument("--group-size", type=int, default=64, help="Group size for quantization.")
    parser.add_argument("--num-samples", type=int, default=1024, help="Number of samples to use for training.")
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--learning-rate", type=float, default=1e-5)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--data-path", type=str, default="allenai/tulu-3-sft-mixture", help="HuggingFace dataset path.")
    parser.add_argument("--temperature", type=float, default=0.5, help="Temperature scaling for the loss.")
    parser.add_argument("--grad-accum-steps", type=int, default=1, help="Micro-batches per optimizer step.")
    args = parser.parse_args()

    group = mx.distributed.init()
    num_samples = args.num_samples
    if num_samples % group.size() > 0:
        num_samples += group.size() - num_samples % group.size()

    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    model_path = get_model_path(args.model, revision=None)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
    # Use the sample count rounded up to a multiple of the world size above.
    calibration_data = load_data(tokenizer, args.data_path, num_samples)
    if args.quantized_model is not None:
        q_model_path = get_model_path(args.quantized_model, revision=None)
        q_model, config, _ = fetch_from_hub(q_model_path, lazy=True)
    else:
        q_model = copy.deepcopy(model)
        _, config = quantize_model(q_model, config, q_group_size=args.group_size, q_bits=args.bits)

    opt = optimizers.Adam(learning_rate=args.learning_rate, bias_correction=True)
    dwq_quantize(
        model,
        q_model,
        opt,
        calibration_data,
        batch_size=args.batch_size,
        max_seq_length=args.max_seq_length,
        temperature=args.temperature,
        grad_accum_steps=args.grad_accum_steps,
    )
    save_model(q_model, tokenizer, config, model_path, args.mlx_path, args.model)


if __name__ == "__main__":
    main()
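Once the run finishes, the quantized model is written to --mlx-path and can be sanity-checked with mlx-lm's top-level API (a minimal sketch, assuming mlx-lm is installed and the default "mlx_model" path was kept):

from mlx_lm import load, generate

model, tokenizer = load("mlx_model")  # path passed as --mlx-path above
print(generate(model, tokenizer, prompt="Hello", max_tokens=32))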