Skip to content

Instantly share code, notes, and snippets.

@Blaizzy
Last active January 13, 2026 17:15
Show Gist options
  • Select an option

  • Save Blaizzy/4e3098ac059448ffc57668b712cda1a2 to your computer and use it in GitHub Desktop.

Select an option

Save Blaizzy/4e3098ac059448ffc57668b712cda1a2 to your computer and use it in GitHub Desktop.
Decode Stream
import json
from functools import partial
from json import JSONDecodeError
from typing import List
from transformers import AutoTokenizer
import tokenizers
# U+FFFD, the Unicode replacement character. ``StreamingDetokenizer.last_segment``
# returns an empty string while the accumulated text ends with this character.
REPLACEMENT_CHAR = "\ufffd"
def _remove_space(x):
if x and x[0] == " ":
return x[1:]
return x
class StreamingDetokenizer:
    """Streaming detokenizer that decodes one token at a time.

    Each token passed to ``add_token`` is decoded immediately through a
    ``tokenizers.decoders.DecodeStream`` and the resulting text is accumulated.

    Example usage is as follows:

        detokenizer = ...

        # Reset the tokenizer state
        detokenizer.reset()

        for token in generate(...):
            detokenizer.add_token(token.item())

            # Contains the whole text decoded so far.
            detokenizer.text

            # Contains the printable segment (usually a word) since the last
            # time it was accessed
            detokenizer.last_segment

            # Contains the tokens committed so far. Tokens are committed when
            # ``text`` sees pending text ending in a newline, or on finalize().
            detokenizer.tokens

        # Make sure that any pending tokens are committed
        detokenizer.finalize()
    """

    def __init__(self, tokenizer):
        # The stream decoder is driven with the tokenizer's ``_tokenizer``
        # attribute rather than with the wrapper object itself.
        self.tokenizer = tokenizer._tokenizer
        self.reset()

    def reset(self):
        """Drop all accumulated tokens/text and start a fresh decode stream."""
        # The decode stream is fed one id per step() call and keeps its own
        # running state, so a fresh instance is needed for a new sequence.
        self.stream_decoder = tokenizers.decoders.DecodeStream(skip_special_tokens=True)
        self.offset = 0
        self._tokens = []
        self._text = ""
        self._current_tokens = []
        self._current_text = ""

    def add_token(self, token, skip_special_token_ids: List[int] = None):
        """Decode ``token`` and append its text to the pending segment.

        Args:
            token: Token id to add.
            skip_special_token_ids: Optional ids that are ignored entirely
                (neither recorded nor decoded).
        """
        if skip_special_token_ids and token in skip_special_token_ids:
            return
        self._current_tokens.append(token)
        # Decode immediately, one token at a time. step() may return nothing
        # when the token does not yet complete a printable chunk.
        decoded = self.stream_decoder.step(self.tokenizer, token)
        if decoded:
            self._current_text += decoded

    def finalize(self):
        """Commit the pending tokens and text."""
        self._tokens.extend(self._current_tokens)
        self._text += self._current_text
        self._current_tokens = []
        self._current_text = ""

    @property
    def text(self):
        """All text decoded so far (committed plus pending).

        The pending text was already produced incrementally by ``add_token``;
        it must not be fed through the stream decoder a second time here.
        """
        # Commit the pending segment once it ends with a newline.
        if self._current_text and self._current_text[-1] == "\n":
            self._tokens.extend(self._current_tokens)
            self._text += self._current_text
            self._current_tokens.clear()
            self._current_text = ""
        return self._text + self._current_text

    @property
    def tokens(self):
        """Tokens committed so far (pending tokens are excluded)."""
        return self._tokens

    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        text = self.text
        # Hold the segment back while the text ends with the replacement
        # character; ``offset`` is only advanced when something is returned.
        if text and text[-1] != REPLACEMENT_CHAR:
            segment = text[self.offset :]
            self.offset = len(text)
            return segment
        return ""
class TokenizerWrapper:
    """A wrapper that combines an HF tokenizer and a detokenizer.

    ``detokenizer`` returns the detokenizer instance and ``tokenizer`` returns
    the wrapped huggingface tokenizer. Accessing any other public attribute is
    forwarded to the huggingface tokenizer. Names starting with an underscore
    are never forwarded.
    """

    def __init__(self, tokenizer, detokenizer_class=StreamingDetokenizer):
        self._tokenizer = tokenizer
        self._detokenizer = detokenizer_class(tokenizer)

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup failed. Refusing private
        # names avoids infinite recursion when ``_tokenizer`` is not set yet
        # and keeps dunder probes from leaking through to the tokenizer.
        if attr.startswith("_"):
            raise AttributeError(f"TokenizerWrapper object has no attribute {attr}")
        if attr == "detokenizer":
            return self._detokenizer
        if attr == "tokenizer":
            return self._tokenizer
        # Everything else is delegated, as the class docstring promises.
        return getattr(self._tokenizer, attr)
def load_tokenizer(model_path, return_tokenizer=True, tokenizer_config_extra=None):
    """Load a huggingface tokenizer paired with the streaming detokenizer.

    Args:
        model_path: Value passed to ``AutoTokenizer.from_pretrained``.
        return_tokenizer: When True, return a ``TokenizerWrapper`` around the
            loaded tokenizer. When False, nothing is loaded and only the
            detokenizer class is returned.
        tokenizer_config_extra: Optional extra keyword arguments for
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``TokenizerWrapper`` or the detokenizer class, see ``return_tokenizer``.
    """
    # No detokenizer type is inferred here: StreamingDetokenizer is always used.
    detokenizer_class = StreamingDetokenizer
    if not return_tokenizer:
        return detokenizer_class

    tokenizer_kwargs = tokenizer_config_extra or {}
    return TokenizerWrapper(
        AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs),
        detokenizer_class,
    )
import glob
import importlib
import inspect
import json
import logging
from io import BytesIO
from pathlib import Path
from textwrap import dedent
from typing import Any, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import requests
import soundfile as sf
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from PIL import Image, ImageOps
from transformers import (
AutoConfig,
AutoProcessor,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from .models.base import BaseImageProcessor
from .tokenizer_utils import load_tokenizer
from .trainer import apply_lora_layers
# Constants

# Maps a config ``model_type`` (lower-cased by ``get_model_and_args``) to the
# name of the ``mlx_vlm.models`` submodule that is imported for it.
MODEL_REMAPPING = {
    "llava_qwen2": "fastvlm", # Apple's FastVLM, note it's different to the one below
    "llava-qwen2": "llava_bunny",
    "bunny-llama": "llava_bunny",
    "lfm2-vl": "lfm2_vl",
    "cohere2_vision": "aya_vision",
}
# Default shard size limit for ``make_shards``, in GiB (the value is shifted
# left by 30 bits to obtain bytes).
MAX_FILE_SIZE_GB = 5
# NOTE(review): not referenced in the visible source; presumably the dtype
# names accepted when converting a model - confirm against the convert script.
MODEL_CONVERSION_DTYPES = ["float16", "bfloat16", "float32"]
def skip_multimodal_module(path: str) -> bool:
    """
    Tell whether a module path belongs to a vision/audio (multimodal) component.

    Args:
        path: The module path to check

    Returns:
        bool: True if the path contains one of the known vision/audio module
        names, False otherwise
    """
    multimodal_markers = (
        "vision_model",
        "vision_tower",
        "sam_model",
        "audio_model",
        "audio_tower",
    )
    return any(marker in path for marker in multimodal_markers)
def get_model_and_args(config: dict):
    """
    Import the model implementation module selected by the configuration.

    Args:
        config (dict): The model configuration. Its ``"model_type"`` entry is
            lower-cased and translated through ``MODEL_REMAPPING``.

    Returns:
        A tuple ``(module, model_type)``: the imported
        ``mlx_vlm.models.<model_type>`` module and the resolved model type name.

    Raises:
        ValueError: If the module for the model type cannot be imported.
    """
    model_type = config["model_type"].lower()
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"mlx_vlm.models.{model_type}")
    except ImportError as exc:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        # Keep the original ImportError attached: it may point at a missing
        # dependency of the model module rather than at an unknown model type.
        raise ValueError(msg) from exc
    return arch, model_type
def get_model_path(
    path_or_hf_repo: str, revision: Optional[str] = None, force_download: bool = False
) -> Path:
    """
    Resolve a model location to a local path, downloading it when necessary.

    If ``path_or_hf_repo`` exists locally it is returned as-is; otherwise it is
    treated as a Hugging Face repository ID and a snapshot is downloaded.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
        force_download (bool): Forwarded to ``snapshot_download``.

    Returns:
        Path: The path to the model.
    """
    local_path = Path(path_or_hf_repo)
    if local_path.exists():
        return local_path

    wanted_files = [
        "*.json",
        "*.safetensors",
        "*.py",
        "*.model",
        "*.tiktoken",
        "*.txt",
        "*.jinja",
    ]
    unwanted_files = ["consolidated*.safetensors"]
    snapshot_dir = snapshot_download(
        repo_id=path_or_hf_repo,
        revision=revision,
        allow_patterns=wanted_files,
        ignore_patterns=unwanted_files,
        force_download=force_download,
    )
    return Path(snapshot_dir)
def load_model(model_path: Path, lazy: bool = False, **kwargs) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        **kwargs: Forwarded to ``load_config``.

    Returns:
        nn.Module: The loaded and initialized model, switched to eval mode.

    Raises:
        FileNotFoundError: If the config or the weight files (.safetensors) are not found.
        ValueError: If the model type is not supported.
    """
    config = load_config(model_path, **kwargs)

    # Find all .safetensors files in the model_path, excluding consolidated
    # model weights. Files named ``consolidated*.safetensors`` are skipped to
    # match the ignore pattern used when downloading; the suffix check keeps
    # the previous exclusion rule as well.
    weight_files = [
        wf
        for wf in glob.glob(str(model_path / "*.safetensors"))
        if not Path(wf).name.startswith("consolidated")
        and not wf.endswith("consolidated.safetensors")
    ]

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        message = f"""
No safetensors found in {model_path}
Create safetensors using the following code:
```
from transformers import AutoModelForCausalLM, AutoProcessor
model_id= "<huggingface_model_id>"
model = AutoModelForCausalLM.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
model.save_pretrained("<local_dir>")
processor.save_pretrained("<local_dir>")
```
Then use the <local_dir> as the --hf-path in the convert script.
```
python -m mlx_vlm.convert --hf-path <local_dir> --mlx-path <mlx_dir>
```
"""
        raise FileNotFoundError(message)

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, _ = get_model_and_args(config=config)

    # Initialize text, vision and audio configs if not present
    config.setdefault("text_config", {})
    config.setdefault("vision_config", {})
    config.setdefault("audio_config", {})

    # Initialize model config and update it with module configs
    model_config = model_class.ModelConfig.from_dict(config)
    modules = ["text", "vision", "perceiver", "projector", "audio"]
    model_config = update_module_configs(model_config, model_class, config, modules)

    model = model_class.Model(model_config)

    # Sanitize weights: first with the model instance, then with each
    # sub-model class instantiated from its own config.
    weights = sanitize_weights(model, weights)
    weights = sanitize_weights(
        model_class.VisionModel, weights, model_config.vision_config
    )
    weights = sanitize_weights(
        model_class.LanguageModel, weights, model_config.text_config
    )
    if hasattr(model_class, "AudioModel"):
        weights = sanitize_weights(
            model_class.AudioModel, weights, model_config.audio_config
        )

    quantization = config.get("quantization", None)
    if quantization is not None:
        # Handle legacy models which may or may not have vision quantized
        # TODO: Re-upload the models with the new quantization config and remove this
        skip_vision = config.get("vision_config", {}).get("skip_vision", False)

        def get_class_predicate(p, m):
            # Vision/audio modules are skipped only when the config sets
            # ``vision_config.skip_vision``.
            if skip_multimodal_module(p) and skip_vision:
                return False
            # Handle custom per layer quantizations
            if p in quantization:
                return quantization[p]
            if not hasattr(m, "to_quantized"):
                return False
            # Skip layers not divisible by 64
            if hasattr(m, "weight") and m.weight.size % 64 != 0:
                return False
            # Handle legacy models which may not have everything quantized
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=get_class_predicate,
        )

    model.load_weights(list(weights.items()))
    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model
def sanitize_weights(model_obj, weights, config=None):
    """Run ``model_obj.sanitize`` over ``weights`` when such a method exists.

    When ``config`` is given, ``model_obj`` is called with it first and the
    result's ``sanitize`` is used. Objects without a ``sanitize`` attribute
    leave ``weights`` untouched.
    """
    if not hasattr(model_obj, "sanitize"):
        return weights
    sanitizer = model_obj if config is None else model_obj(config)
    return sanitizer.sanitize(weights)
def update_module_configs(model_config, model_class, config, modules):
    """Replace sub-module configs on ``model_config`` with typed config objects.

    For every name in ``modules`` whose ``<name>_config`` attribute exists on
    ``model_config``, the class ``<Name>Config`` is looked up on ``model_class``
    and built via ``from_dict`` from ``config["<name>_config"]``.

    Args:
        model_config: The model configuration object that will be updated
        model_class: The object exposing the component config classes
        config: Dictionary containing configuration parameters
        modules: List of module names to update configs for (e.g. ["text", "vision"])

    Returns:
        The updated model_config object
    """
    for module_name in modules:
        attr_name = f"{module_name}_config"
        if not hasattr(model_config, attr_name):
            continue
        module_config_cls = getattr(model_class, f"{module_name.title()}Config")
        module_config = module_config_cls.from_dict(config[attr_name])
        setattr(model_config, attr_name, module_config)
    return model_config
def load(
    path_or_hf_repo: str,
    adapter_path: Optional[str] = None,
    lazy: bool = False,
    revision: Optional[str] = None,
    **kwargs,
) -> Tuple[nn.Module, Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
    """
    Load the model and its processor from a local path or a huggingface repository.

    Args:
        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        revision (str, optional): A revision id which can be a branch name,
            a tag, or a commit hash. Default: ``None``.
        **kwargs: Forwarded to the model, image-processor and processor loaders;
            ``force_download`` is also read here for the download step.

    Returns:
        A tuple containing the loaded model and the processor.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If the model type is not supported.
    """
    model_path = get_model_path(
        path_or_hf_repo,
        force_download=kwargs.get("force_download", False),
        revision=revision,
    )

    model = load_model(model_path, lazy, **kwargs)
    if adapter_path is not None:
        model = apply_lora_layers(model, adapter_path)
        model.eval()

    image_processor = load_image_processor(model_path, **kwargs)

    # The EOS id(s) come from the model config when it defines them.
    processor = load_processor(
        model_path,
        True,
        eos_token_ids=getattr(model.config, "eos_token_id", None),
        **kwargs,
    )

    if image_processor is not None:
        processor.image_processor = image_processor

    return model, processor
def load_config(model_path: Union[str, Path], **kwargs) -> dict:
    """Read ``config.json`` from a model directory.

    Args:
        model_path: Local path or Hugging Face repo ID to load config from.
            A string is first resolved through ``get_model_path``.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        dict: Model configuration

    Raises:
        FileNotFoundError: If config.json is not found at the path
    """
    config_dir = (
        get_model_path(model_path) if isinstance(model_path, str) else model_path
    )
    try:
        with open(config_dir / "config.json", encoding="utf-8") as handle:
            return json.load(handle)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Config not found at {config_dir}") from exc
def load_image_processor(model_path: Union[str, Path], **kwargs) -> BaseImageProcessor:
    """Instantiate the model module's ``ImageProcessor``, if it defines one.

    Args:
        model_path: Local path or Hugging Face repo ID of the model.
        **kwargs: Forwarded to ``load_config``; when empty,
            ``trust_remote_code=True`` is passed instead.

    Returns:
        The image processor instance, or ``None`` when the model module has no
        ``ImageProcessor`` attribute.
    """
    if isinstance(model_path, str):
        model_path = get_model_path(model_path)

    config_kwargs = kwargs if kwargs else {"trust_remote_code": True}
    config = load_config(model_path, **config_kwargs)
    model_class, _ = get_model_and_args(config)

    if not hasattr(model_class, "ImageProcessor"):
        return None

    processor_cls = model_class.ImageProcessor
    # Only pass the config when the constructor declares a ``config`` parameter.
    if "config" in inspect.signature(processor_cls.__init__).parameters:
        return processor_cls(config=config)
    return processor_cls()
def load_processor(
    model_path, add_detokenizer=True, eos_token_ids=None, **kwargs
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the HF processor and optionally attach detokenizer and stopping criteria.

    Args:
        model_path: Value passed to ``AutoProcessor.from_pretrained``.
        add_detokenizer: When True, the processor's ``tokenizer`` is replaced by
            the one returned from ``load_tokenizer``, a ``detokenizer`` is
            attached, and a ``StoppingCriteria`` is stored on the tokenizer as
            ``stopping_criteria``. When False the processor is returned as loaded.
        eos_token_ids: EOS id(s) for the stopping criteria. When ``None`` the
            tokenizer's own EOS id(s) are used.
        **kwargs: Forwarded to ``AutoProcessor.from_pretrained``.

    Returns:
        The processor.
    """
    processor = AutoProcessor.from_pretrained(model_path, **kwargs)
    if not add_detokenizer:
        return processor

    # With return_tokenizer=True this is a wrapper exposing both the tokenizer
    # and its detokenizer.
    tokenizer_wrapper = load_tokenizer(model_path, return_tokenizer=True)
    tokenizer_obj = tokenizer_wrapper.tokenizer

    processor.tokenizer = tokenizer_obj
    processor.detokenizer = tokenizer_wrapper.detokenizer

    # Determine the EOS token IDs, prioritizing the function argument. Fall
    # back to ``eos_token_id`` when the tokenizer has no ``eos_token_ids``.
    final_eos_token_ids = eos_token_ids
    if final_eos_token_ids is None:
        final_eos_token_ids = getattr(tokenizer_obj, "eos_token_ids", None)
    if final_eos_token_ids is None:
        final_eos_token_ids = tokenizer_obj.eos_token_id

    # ``processor.tokenizer`` was assigned above, so the criteria always live
    # on the tokenizer.
    processor.tokenizer.stopping_criteria = StoppingCriteria(
        final_eos_token_ids, tokenizer_obj
    )
    return processor
def fetch_from_hub(
    model_path: Path, lazy: bool = False, **kwargs
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
    """Load the model, its config dict and its processor (without detokenizer).

    Args:
        model_path: Directory to load from.
        lazy: Forwarded to ``load_model``.
        **kwargs: Forwarded to ``load_model``, ``load_config`` and ``load_processor``.

    Returns:
        A ``(model, config, processor)`` tuple.
    """
    model = load_model(model_path, lazy, **kwargs)
    config = load_config(model_path, **kwargs)
    eos_from_config = config.get("eos_token_id", None)
    processor = load_processor(
        model_path,
        add_detokenizer=False,
        eos_token_ids=eos_from_config,
        **kwargs,
    )
    return model, config, processor
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Weights are added to the current shard in iteration order; a new shard is
    started when adding the next weight would push the shard over the limit.
    A single weight larger than the limit gets a shard of its own.

    Args:
        weights (dict): Model weights; every value must expose ``nbytes``.
        max_file_size_gb (int): Maximum size of each shard in gigabytes
            (GiB). Fractional values are accepted.

    Returns:
        list: List of weight shards (dicts). Empty ``weights`` yields a single
        empty shard.
    """
    max_file_size_bytes = int(max_file_size_gb * (1 << 30))
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        # Only close a shard that already holds something; otherwise an
        # oversized first weight would leave an empty shard behind.
        if shard and shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    The model card of ``hf_path`` is loaded, tagged with ``mlx``, given a new
    body describing the conversion and written to ``README.md`` inside
    ``path`` before the folder is uploaded.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    # NOTE: this ``logging`` is huggingface_hub's logging helper and shadows
    # the stdlib module inside this function only.
    from huggingface_hub import HfApi, ModelCard, logging

    from . import __version__

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    # The first link previously had an empty target; it now points at the
    # source model like the "original model card" link below it.
    card.text = dedent(
        f"""
        # {upload_repo}
        This model was converted to MLX format from [`{hf_path}`](https://huggingface.co/{hf_path}) using mlx-vlm version **{__version__}**.
        Refer to the [original model card](https://huggingface.co/{hf_path}) for more details on the model.
        ## Use with mlx
        ```bash
        pip install -U mlx-vlm
        ```
        ```bash
        python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temperature 0.0 --prompt "Describe this image." --image <path_to_image>
        ```
        """
    )
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
    """
    Penalize the logits of tokens that were already generated.

    Negative logits of those tokens are multiplied by ``penalty`` and the
    others are divided by it. ``logits`` is updated in place and returned.

    Paper: https://arxiv.org/abs/1909.05858

    Args:
        logits (mx.array): The logits produced by the language model.
        generated_tokens (any): A list of N previous tokens.
        penalty (float): The repetition penalty factor to be applied.

    Returns:
        logits (mx.array): Logits with repetition penalty applied to generated tokens.
    """
    if len(generated_tokens) == 0:
        return logits

    token_indices = mx.array(list(generated_tokens))
    picked = logits[:, token_indices]
    logits[:, token_indices] = mx.where(
        picked < 0, picked * penalty, picked / penalty
    )
    return logits
def save_weights(
    save_path: Union[str, Path],
    model: nn.Module,
    *,
    donate_weights: bool = False,
) -> None:
    """Save model weights into specified directory.

    The flattened parameters are split with ``make_shards`` and each shard is
    written as a safetensors file, followed by a ``model.safetensors.index.json``
    that records the total size and which file holds each weight.

    Args:
        save_path: Target directory; created if missing.
        model: Model whose parameters are saved.
        donate_weights: When True, the local name->array dict is emptied before
            writing so it no longer references the arrays.
    """
    if isinstance(save_path, str):
        save_path = Path(save_path)
    weights = dict(tree_flatten(model.parameters()))
    # Drop this function's own reference to the model.
    del model
    save_path.mkdir(parents=True, exist_ok=True)
    shards = make_shards(weights)
    shards_count = len(shards)
    # A single shard is written under the plain name; several shards use the
    # numbered "NNNNN-of-NNNNN" pattern.
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )
    # Must be computed before ``weights`` is cleared below.
    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
    # Write the weights and make sure no references are kept other than the
    # necessary ones
    if donate_weights:
        weights.clear()
        del weights
    for i in range(len(shards)):
        shard = shards[i]
        # Remove the list's reference so the shard is only held by ``shard``.
        shards[i] = None
        shard_name = shard_file_format.format(i + 1, shards_count)
        shard_path = save_path / shard_name
        mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})
        for weight_name in shard.keys():
            index_data["weight_map"][weight_name] = shard_name
        del shard
    # Store the weight map sorted by weight name.
    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }
    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(
            index_data,
            f,
            indent=4,
        )
def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Write the model configuration to ``config_path`` as indented JSON.

    The keys ``_name_or_path`` and ``torch_dtype`` are removed from ``config``
    itself (the caller's dict is modified), and the remaining entries are
    written sorted by key.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Clean unused keys
    for stale_key in ("_name_or_path", "torch_dtype"):
        config.pop(stale_key, None)

    # sort the config for better readability
    ordered = {key: config[key] for key in sorted(config)}

    with open(config_path, "w") as fid:
        json.dump(ordered, fid, indent=4)
def load_image(image_source: Union[str, Path, BytesIO], timeout: int = 10):
    """
    Helper function to load an image from either a URL or file.

    Args:
        image_source: A ``BytesIO`` buffer, a ``data:image/...`` URI string, a
            path (``str`` or ``Path``) to an existing file, or an
            ``http(s)://`` URL string.
        timeout: Timeout in seconds for the HTTP request.

    Returns:
        The image, EXIF-transposed and converted to RGB.

    Raises:
        ValueError: If the source is none of the supported kinds, the data URI
            is malformed, or the image cannot be loaded.
    """
    # String-only checks are computed once so that BytesIO/Path inputs are
    # never asked for ``startswith``.
    is_text = isinstance(image_source, str)
    is_data_uri = is_text and image_source.startswith("data:image/")

    if (
        isinstance(image_source, BytesIO)
        or is_data_uri
        or Path(image_source).is_file()
    ):
        try:
            # for base64 encoded images
            if is_data_uri:
                import base64

                if "," not in image_source:
                    raise ValueError(
                        "Invalid data URI format - missing comma separator"
                    )
                _, data = image_source.split(",", 1)
                image_source = BytesIO(base64.b64decode(data))
            image = Image.open(image_source)
        except IOError as e:
            raise ValueError(
                f"Failed to load image from {image_source} with error: {e}"
            ) from e
    elif is_text and image_source.startswith(("http://", "https://")):
        try:
            response = requests.get(image_source, stream=True, timeout=timeout)
            response.raise_for_status()
            image = Image.open(response.raw)
        except Exception as e:
            raise ValueError(
                f"Failed to load image from URL: {image_source} with error {e}"
            ) from e
    else:
        raise ValueError(
            f"The image {image_source} must be a valid URL or existing file."
        )

    image = ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
def resize_image(img, max_size):
    """Scale ``img`` to fit ``max_size`` (width, height), keeping its aspect ratio.

    The same scale factor is applied to both sides and the resulting sizes are
    truncated to integers before calling ``img.resize``.
    """
    scale = min(max_size[0] / img.width, max_size[1] / img.height)
    target_width = int(img.width * scale)
    target_height = int(img.height * scale)
    return img.resize((target_width, target_height))
def process_image(img, resize_shape, image_processor):
    """Load ``img`` if it is a string and optionally resize it.

    Resizing happens only when ``resize_shape`` is given and
    ``image_processor`` is not a ``BaseImageProcessor`` instance.
    """
    image = load_image(img) if isinstance(img, str) else img
    uses_custom_processor = isinstance(image_processor, BaseImageProcessor)
    if resize_shape is None or uses_custom_processor:
        return image
    return resize_image(image, resize_shape)
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample audio using linear interpolation.

    1-D input is resampled directly. 2-D input is resampled per channel and
    returned as (samples, channels); it is transposed first when its first
    dimension is smaller than its second. Input is returned unchanged when the
    two sample rates are equal.

    Raises:
        ValueError: If ``audio`` is neither 1-D nor 2-D.
    """
    if orig_sr == target_sr:
        return audio

    ratio = target_sr / orig_sr

    if audio.ndim not in (1, 2):
        raise ValueError(f"Audio array has unsupported shape: {audio.shape}")

    if audio.ndim == 1:
        # Mono audio
        n_samples = len(audio)
        target_positions = np.linspace(0, n_samples - 1, int(n_samples * ratio))
        return np.interp(target_positions, np.arange(n_samples), audio)

    # Multi-channel audio: make sure the layout is (samples, channels)
    if audio.shape[0] < audio.shape[1]:
        audio = audio.T

    n_samples, n_channels = audio.shape
    out_length = int(n_samples * ratio)
    source_positions = np.arange(n_samples)
    target_positions = np.linspace(0, n_samples - 1, out_length)

    resampled = np.zeros((out_length, n_channels))
    for channel in range(n_channels):
        resampled[:, channel] = np.interp(
            target_positions, source_positions, audio[:, channel]
        )
    return resampled
def load_audio(
    file: str,
    sr: int,
    timeout: int = 10,
):
    """
    Helper function to load audio from either a URL or file.

    The audio is read as 2-D, resampled to ``sr`` when its sample rate
    differs, and averaged over axis 1 before being returned.

    Raises:
        ValueError: If downloading or decoding audio from a URL fails.
    """
    is_url = file.startswith(("http://", "https://"))
    if not is_url:
        audio, sample_rate = sf.read(file, always_2d=True)
    else:
        try:
            response = requests.get(file, stream=True, timeout=timeout)
            response.raise_for_status()
            audio, sample_rate = sf.read(BytesIO(response.content), always_2d=True)
        except Exception as e:
            raise ValueError(
                f"Failed to load audio from URL: {file} with error {e}"
            ) from e

    if sample_rate != sr:
        audio = resample_audio(audio, sample_rate, sr)

    mono = np.array(audio).mean(axis=1)
    return mono
def process_inputs(
    processor,
    prompts,
    images=None,
    audio=None,
    add_special_tokens=False,
    padding=True,
    padding_side="left",
    return_tensors="mlx",
    **kwargs,
):
    """Call the processor with only the arguments it declares.

    Uses ``processor.process`` when present, otherwise calls ``processor``
    itself.

    Args:
        processor: Callable processor, or an object with a ``process`` method.
        prompts: Passed as ``text``.
        images: Passed as ``images``.
        audio: Passed as ``audio`` when not ``None``.
        add_special_tokens: Passed only if the processor declares the parameter.
        padding, padding_side, return_tensors: Always passed.
        **kwargs: Every entry whose name is a declared parameter of the
            processor is forwarded; the rest are ignored.

    Returns:
        Whatever the processor returns.

    Raises:
        ValueError: If ``audio`` is given but the processor has no ``audio`` parameter.
    """
    # Get the process method from the processor
    process_method = getattr(processor, "process", processor)
    accepted = inspect.signature(process_method).parameters

    # Prepare arguments
    args = {
        "text": prompts,
        "images": images,
        "padding": padding,
        "padding_side": padding_side,
        "return_tensors": return_tensors,
    }

    # Add special tokens if supported
    if "add_special_tokens" in accepted:
        args["add_special_tokens"] = add_special_tokens

    # Forward all supported extra keyword arguments, not just the first match.
    for name, value in kwargs.items():
        if name in accepted:
            args[name] = value

    # Add audio if provided and supported
    if audio is not None:
        if "audio" not in accepted:
            raise ValueError(f"Processor {processor} does not support audio parameter")
        args["audio"] = audio

    return process_method(**args)
def process_inputs_with_fallback(
    processor,
    prompts,
    images,
    audio,
    add_special_tokens=False,
    return_tensors="mlx",
    **kwargs,
):
    """Run ``process_inputs`` and report any failure as a ``ValueError``.

    Despite the name there is currently a single attempt, made with the
    requested ``return_tensors``.

    Raises:
        ValueError: If processing fails; the original exception is chained.
    """
    try:
        return process_inputs(
            processor,
            prompts=prompts,
            images=images,
            audio=audio,
            add_special_tokens=add_special_tokens,
            return_tensors=return_tensors,
            **kwargs,
        )
    except Exception as e:
        # Chain the cause so the processor's own traceback is not lost.
        raise ValueError(f"Failed to process inputs with error: {e}") from e
def prepare_inputs(
    processor,
    images=None,
    audio=None,
    prompts=None,
    image_token_index=None,
    resize_shape=None,
    add_special_tokens=False,
    padding=True,
    padding_side="left",
    pad_to_uniform_size=False,
    **kwargs,
):
    """Build the model input dict from prompts and optional images/audio.

    Three paths exist:
      * no images and no audio: the tokenizer is called directly and only
        ``input_ids`` and ``attention_mask`` are returned;
      * the processor has a ``BaseImageProcessor`` image processor: prompts
        are split on ``"<image>"`` and token ids are assembled manually;
      * otherwise: everything is delegated to ``process_inputs_with_fallback``.

    Returns:
        dict: Model inputs; always contains ``attention_mask`` (possibly
        ``None`` in the delegated path) plus the keys produced by the path taken.
    """
    # Text-only path.
    if not images and not audio:
        tokenizer = (
            processor.tokenizer if hasattr(processor, "tokenizer") else processor
        )
        inputs = tokenizer(
            prompts,
            add_special_tokens=add_special_tokens,
            padding=padding,
            padding_side=padding_side,
        )
        # The tokenizer output is wrapped in an extra outer list here.
        input_ids = mx.array([inputs.input_ids])
        mask = mx.array([inputs.attention_mask])
        return {
            "input_ids": input_ids,
            "attention_mask": mask,
        }
    # Process images
    if images is not None:
        if not isinstance(images, list):
            images = [images]
        image_processor = (
            processor.image_processor if hasattr(processor, "image_processor") else None
        )
        images = [process_image(img, resize_shape, image_processor) for img in images]
        # Pad all images to the size of the largest
        if len(images) > 1 and pad_to_uniform_size:
            max_width = max(img.width for img in images)
            max_height = max(img.height for img in images)
            padded_images = []
            for img in images:
                if img.width != max_width or img.height != max_height:
                    # Create a new image with the max dimensions, filled with
                    # white (255, 255, 255)
                    padded_img = Image.new(
                        "RGB", (max_width, max_height), (255, 255, 255)
                    )
                    # Center the original image
                    x_offset = (max_width - img.width) // 2
                    y_offset = (max_height - img.height) // 2
                    padded_img.paste(img, (x_offset, y_offset))
                    padded_images.append(padded_img)
                else:
                    padded_images.append(img)
            images = padded_images
    # Process audio
    if audio is not None:
        if not isinstance(audio, list):
            audio = [audio]
        if len(audio) > 1:
            print(
                "\033[33mWarning\033[0m: Single prompt with multiple audio files is not supported yet. Using the first audio file.\n"
            )
            audio = audio[:1]
        # Load each file at the feature extractor's sampling rate.
        audio = [
            load_audio(audio_file, sr=processor.feature_extractor.sampling_rate)
            for audio_file in audio
        ]
    model_inputs = {}
    # Manual path for processors whose image processor is a BaseImageProcessor.
    if hasattr(processor, "image_processor") and isinstance(
        processor.image_processor, BaseImageProcessor
    ):
        if not isinstance(prompts, list):
            prompts = [prompts]
        if processor.pad_token is None:
            processor.pad_token = processor.eos_token
        # Tokenize the text on each side of the "<image>" placeholder separately.
        text_chunks = [
            [processor(chunk).input_ids for chunk in prompt.split("<image>")]
            for prompt in prompts
        ]
        # Find the maximum length for padding (+1 for the image token)
        max_length = max(
            sum(len(chunk) for chunk in chunks) + 1 for chunks in text_chunks
        )
        # Pad and create input_ids
        input_ids = []
        for chunks in text_chunks:
            # NOTE(review): only chunks[0] and chunks[1] are used, so exactly
            # one "<image>" placeholder per prompt is assumed - confirm.
            ids = chunks[0] + [image_token_index] + chunks[1]
            # NOTE: this local shadows the ``padding`` parameter from here on.
            padding = [processor.pad_token_id] * (max_length - len(ids))
            input_ids.append(mx.array(ids + padding))
        model_inputs["input_ids"] = mx.array(input_ids)
        pixel_values = processor.image_processor.preprocess(images=images)
        model_inputs["pixel_values"] = mx.array(np.stack(pixel_values))
        # Mask is 1 where the id differs from the pad token id.
        model_inputs["attention_mask"] = mx.array(
            [(ids != processor.pad_token_id) for ids in input_ids]
        ).astype(mx.int32)
    else:
        if hasattr(processor, "tokenizer") and processor.tokenizer.pad_token is None:
            processor.tokenizer.pad_token = processor.tokenizer.eos_token
        inputs = process_inputs_with_fallback(
            processor,
            images=images,
            audio=audio,
            prompts=prompts,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )
        # Rename an "images" entry to "pixel_values".
        if "images" in inputs:
            inputs["pixel_values"] = inputs["images"]
            inputs.pop("images")
        model_inputs["attention_mask"] = (
            inputs["attention_mask"] if "attention_mask" in inputs else None
        )
        # Copy every remaining processor output into model_inputs unchanged
        for key, value in inputs.items():
            if key not in model_inputs:
                model_inputs[key] = value
    return model_inputs
class StoppingCriteria:
    """Decide whether a token id is one of the configured EOS token ids.

    The instance keeps its own list of ids, so later additions never modify
    the list (for example a model config value) it was created from.
    """

    def __init__(self, eos_token_ids: List[int], tokenizer=None):
        """
        Args:
            eos_token_ids: A single id, an iterable of ids, or ``None`` for no ids.
            tokenizer: Tokenizer used by ``add_eos_token_ids`` and ``reset``.
        """
        self.eos_token_ids = self._as_id_list(eos_token_ids)
        self.tokenizer = tokenizer

    @staticmethod
    def _as_id_list(eos_token_ids) -> List[int]:
        """Normalize ``None``/int/iterable into a fresh list."""
        if eos_token_ids is None:
            return []
        if isinstance(eos_token_ids, int):
            return [eos_token_ids]
        return list(eos_token_ids)

    def add_eos_token_ids(
        self, new_eos_token_ids: Union[int, str, List[Union[int, str]]] = None
    ):
        """
        Add new token IDs to the list of EOS token IDs.

        Args:
            new_eos_token_ids: Integer, string, or list of integers/strings.
                Integers are added as-is. A string is encoded with the
                tokenizer (prefixed with a space, without special tokens) and
                the last resulting id is added.

        Raises:
            ValueError: If no tokenizer was provided.
        """
        if new_eos_token_ids is None:
            return
        if self.tokenizer is None:
            raise ValueError("Processor is not provided")

        if isinstance(new_eos_token_ids, (int, str)):
            new_eos_token_ids = [new_eos_token_ids]

        # Resolve everything first so a failing entry leaves the list untouched.
        resolved = [
            token
            if isinstance(token, int)
            else self.tokenizer.encode(" " + token, add_special_tokens=False)[-1]
            for token in new_eos_token_ids
        ]
        self.eos_token_ids.extend(resolved)

    def reset(self, eos_token_ids: List[int] = None):
        """Replace the EOS ids, defaulting to the tokenizer's own EOS id(s)."""
        if eos_token_ids is None:
            # Prefer ``eos_token_ids`` and fall back to ``eos_token_id``.
            eos_token_ids = getattr(self.tokenizer, "eos_token_ids", None)
            if eos_token_ids is None:
                eos_token_ids = self.tokenizer.eos_token_id
        self.eos_token_ids = self._as_id_list(eos_token_ids)

    def __call__(self, input_ids: "mx.array") -> bool:
        return input_ids in self.eos_token_ids
def print_array_report(t: mx.array, label: Optional[str]) -> dict:
    """
    Print and return a report of an MLX array similar to PyTorch's tensor representation.

    Args:
        t: MLX array to analyze
        label: Name stored under ``"label"``; ``"array"`` is used when falsy

    Returns:
        Dictionary containing shape, dtype, value representation, statistics
        (mean, std, min, max) and the label, all as strings
    """
    # Get basic statistics
    mean_val = mx.mean(t)
    std_val = mx.std(t)
    min_val = mx.min(t)
    max_val = mx.max(t)

    report = {
        "shape": f"{tuple(t.shape)}",
        "dtype": str(t.dtype),
        "value": repr(t),
        "mean": f"array({mean_val}, dtype={t.dtype})",
        "std": f"array({std_val}, dtype={t.dtype})",
        "min": f"array({min_val}, dtype={t.dtype})",
        "max": f"array({max_val}, dtype={t.dtype})",
        "label": label if label else "array",
    }

    # Print each field, handling 'value' specially
    print("{")
    for key, value in report.items():
        if key == "value":
            print(f" '{key}': {value},")  # No quotes around value
        else:
            print(f" '{key}': {repr(value)},")
    print("}")
    return report
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment