@jnulzl · Created February 27, 2025 07:01
Qwen2.5-VL GPTQ Quantization Tutorial
demo_qwen_vl.py:

import os
import sys

from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info


def main(model_dir, messages):
    if not os.path.exists(model_dir):
        raise Exception("%s directory doesn't exist!" % (model_dir))
    if "Qwen2-VL" in model_dir:
        from transformers import Qwen2VLForConditionalGeneration as QwenVLForConditionalGeneration
    elif "Qwen2.5-VL" in model_dir:
        from transformers import Qwen2_5_VLForConditionalGeneration as QwenVLForConditionalGeneration
    else:
        raise Exception("model_dir must contain 'Qwen2-VL' or 'Qwen2.5-VL'")
    print(QwenVLForConditionalGeneration)

    # default: load the model on the available device(s)
    model = QwenVLForConditionalGeneration.from_pretrained(
        model_dir,
        attn_implementation="flash_attention_2",
        device_map="auto"
    )

    # We recommend enabling flash_attention_2 for better acceleration and memory saving,
    # especially in multi-image and video scenarios (uncommenting this variant requires `import torch`).
    # model = QwenVLForConditionalGeneration.from_pretrained(
    #     model_dir,
    #     torch_dtype=torch.bfloat16,
    #     attn_implementation="flash_attention_2",
    #     device_map="auto",
    # )

    # default processor
    processor = AutoProcessor.from_pretrained(model_dir)

    # The default range for the number of visual tokens per image is 4-16384. You can set
    # min_pixels and max_pixels according to your needs, e.g. a token count range of
    # 256-1280, to balance speed and memory usage.
    # min_pixels = 256 * 28 * 28
    # max_pixels = 1280 * 28 * 28
    # processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8", min_pixels=min_pixels, max_pixels=max_pixels)

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    # Inference: generate the output and strip the prompt tokens from each sequence
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # print(output_text)
    return output_text


if __name__ == "__main__":
    if 2 != len(sys.argv):
        print("Usage:\n\t python %s model_dir" % (sys.argv[0]))
        sys.exit(-1)
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": "https://raw.githubusercontent.com/ymcui/Chinese-LLaMA-Alpaca-3/refs/heads/main/pics/banner.png",
                },
                {"type": "text", "text": "请描述一下这张图."},
            ],
        }
    ]
    model_dir = sys.argv[1]
    output_text = main(model_dir, messages)
    print(output_text)
quant_qwenvl_gptqmodel.py:

import os
import sys
import logging

import torch
from gptqmodel import GPTQModel, QuantizeConfig, get_best_device
from transformers import AutoTokenizer

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# Set up logging
logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


class QwenVLQuant:
    def __init__(self, model_path, bits):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
        self.quantize_config = QuantizeConfig(
            bits=bits,       # quantize the model to `bits` bits (e.g. 4 or 8)
            group_size=128,  # it is recommended to set the value to 128
        )
        # Load the un-quantized model; by default it is loaded into CPU memory.
        self.model = GPTQModel.load(model_path, self.quantize_config)

    @classmethod
    def __load_dataset(cls, label_path):
        # You then need to prepare your data for calibration. Just put samples into a list,
        # each of which is a typical chat message as shown below. You can specify text and
        # image in the `content` field:
        # dataset = [
        #     # message 0
        #     [
        #         {"role": "system", "content": "You are a helpful assistant."},
        #         {"role": "user", "content": "Tell me who you are."},
        #         {"role": "assistant", "content": "I am a large language model named Qwen..."},
        #     ],
        #     # message 1
        #     [
        #         {
        #             "role": "user",
        #             "content": [
        #                 {"type": "image", "image": "file:///path/to/your/image.jpg"},
        #                 {"type": "text", "text": "Output all text in the image"},
        #             ],
        #         },
        #         {"role": "assistant", "content": "The text in the image is balabala..."},
        #     ],
        #     # other messages...
        #     ...,
        # ]
        # Here we use a caption dataset **only for demonstration**. You should replace it
        # with your own SFT dataset.
        with open(label_path, "r") as fpR:
            lines = fpR.readlines()
        dataset = []
        img_root = os.path.dirname(label_path)
        for line in lines:
            # Split only on the first comma so quoted captions keep their embedded commas.
            line = line.strip().split(",", 1)
            if len(line) < 2:
                continue
            img_path = os.path.join(img_root, line[0])
            if not os.path.exists(img_path):
                continue
            tmp = {}
            tmp["url"] = img_path
            tmp["caption"] = line[1].strip().strip('"')
            dataset.append(tmp)
        '''
        dataset:
        [
            {"url": IMG_PATH1, "caption": CAPTION1},
            {"url": IMG_PATH2, "caption": CAPTION2},
            {"url": IMG_PATH3, "caption": CAPTION3},
            ......
        ]
        '''
        return dataset

    @classmethod
    def __prepare_dataset(cls, label_path, n_sample=512) -> list[list[dict]]:
        dataset = QwenVLQuant.__load_dataset(label_path)
        sample_num = min(n_sample, len(dataset))
        dataset = dataset[:sample_num]
        return [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": sample["url"]},
                        {"type": "text", "text": "generate a caption for this image"},
                    ],
                },
                {"role": "assistant", "content": sample["caption"]},
            ]
            for sample in dataset
        ]

    def quant_model(self, label_path):
        calibration_dataset = QwenVLQuant.__prepare_dataset(label_path)
        print("dataset num : ", len(calibration_dataset))
        # Quantize the model. For this multimodal model the calibration dataset is a list of
        # chat-message conversations, which Qwen2_5_VLGPTQ.prepare_dataset turns into processor inputs.
        self.model.quantize(calibration_dataset)

    def save_quanted_model(self, quant_path):
        # Finally, save the quantized model:
        self.model.save(quant_path)

        # Push the quantized model to the Hugging Face Hub.
        # To use use_auth_token=True, log in first via `huggingface-cli login`,
        # or pass an explicit token with use_auth_token="hf_xxxxxxx".
        # (uncomment the following three lines to enable this feature)
        # repo_id = f"YourUserName/{quantized_model_dir}"
        # commit_message = f"GPTQModel model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
        # model.push_to_hub(repo_id, commit_message=commit_message, use_auth_token=True)

        # Alternatively, you can save and push at the same time.
        # (uncomment the following three lines to enable this feature)
        # repo_id = f"YourUserName/{quantized_model_dir}"
        # commit_message = f"GPTQModel model for {pretrained_model_dir}: {quantize_config.bits}bits, gr{quantize_config.group_size}, desc_act={quantize_config.desc_act}"
        # model.push_to_hub(repo_id, save_dir=quantized_model_dir, commit_message=commit_message, use_auth_token=True)

        # save quantized model using safetensors
        # model.save(quant_path)

        # # load quantized model to the first GPU
        # device = get_best_device()
        # model = GPTQModel.load(quant_path, device=device)

        # # load quantized model to CPU with IPEX kernel linear.
        # # model = GPTQModel.from_quantized(quantized_model_dir, device="cpu")

        # # download quantized model from Hugging Face Hub and load to the first GPU
        # # model = GPTQModel.from_quantized(repo_id, device="cuda:0",)

        # # inference with model.generate
        # print(tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0]))


if __name__ == "__main__":
    if 4 != len(sys.argv):
        print("Usage:\n\t python %s model_path bits label_path" % (sys.argv[0]))
        sys.exit(-1)
    model_path = sys.argv[1]
    bits = int(sys.argv[2])
    label_path = sys.argv[3]
    quant_obj = QwenVLQuant(model_path, bits)
    quant_obj.quant_model(label_path)
    quant_path = model_path + "-GPTQModel-jnulzl-int%d" % (bits)
    quant_obj.save_quanted_model(quant_path)

Qwen2.5-VL GPTQ Quantization Tutorial

Environment

  • System & software: Ubuntu 20.04 + RTX 3090 + CUDA 11.8 + Python 3.10
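
A quick way to confirm that this environment is usable is a minimal PyTorch check. This snippet is only an illustrative sanity check under the versions listed above, not part of the gist:

# Minimal environment sanity check (illustrative, not part of the original gist).
import torch

print(torch.__version__)              # expected: 2.4.1+cu118
print(torch.version.cuda)             # expected: 11.8
print(torch.cuda.is_available())      # should be True with the RTX 3090 visible
print(torch.cuda.get_device_name(0))  # name of the detected GPU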

Dependencies

Install the following dependencies (a version-check snippet follows the list):

accelerate==1.4.0
aiohappyeyeballs==2.4.6
aiohttp==3.11.13
aiosignal==1.3.2
async-timeout==5.0.1
attrs==25.1.0
av==14.2.0
certifi==2025.1.31
charset-normalizer==3.4.1
datasets==3.3.2
decord==0.6.0
device-smi==0.4.0
dill==0.3.8
einops==0.8.1
filelock==3.17.0
flash-attn==2.7.3 # flash_attn-2.7.3+cu11torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
frozenlist==1.5.0
fsspec==2024.12.0
gptqmodel==1.9.0+cu118torch2.4 # gptqmodel-1.9.0+cu118torch2.4-cp310-cp310-linux_x86_64.whl
huggingface-hub==0.29.1
idna==3.10
Jinja2==3.1.5
MarkupSafe==3.0.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
numpy==2.2.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==9.1.0.70
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-nccl-cu11==2.20.5
nvidia-nvtx-cu11==11.8.86
optimum==1.24.0
packaging==24.2
pandas==2.2.3
pillow==11.1.0
propcache==0.3.0
psutil==7.0.0
pyarrow==19.0.1
python-dateutil==2.9.0.post0
pytz==2025.1
PyYAML==6.0.2
qwen-vl-utils==0.0.8
regex==2024.11.6
requests==2.32.3
safetensors==0.5.3
six==1.17.0
sympy==1.13.3
threadpoolctl==3.5.0
tokenicer==0.0.4
tokenizers==0.21.0
torch==2.4.1+cu118 # torch-2.4.1+cu118-cp310-cp310-linux_x86_64.whl
torchaudio==2.4.1+cu118 # torchaudio-2.4.1+cu118-cp310-cp310-linux_x86_64.whl
torchvision==0.19.1+cu118 # torchvision-0.19.1+cu118-cp310-cp310-linux_x86_64.whl
tqdm==4.67.1
transformers==4.49.0
triton==3.0.0
typing_extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
xxhash==3.5.0
yarl==1.18.3
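
After installation, the pinned versions can be double-checked with a short dump. The flash-attn, gptqmodel, torch, torchaudio, and torchvision entries refer to prebuilt cu118 wheels (the .whl file names are noted in the comments above), so they may need to be installed from local wheel files rather than from PyPI. The snippet below is only an illustrative helper, not part of the gist:

# Print the installed versions of the key packages (illustrative helper, not from the gist).
from importlib.metadata import version, PackageNotFoundError

for pkg in ("torch", "torchvision", "transformers", "optimum", "gptqmodel", "qwen-vl-utils", "flash-attn"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "NOT INSTALLED")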

Patch GPTQModel (!!!)

diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py
index 08341897..c851d3f0 100644
--- a/gptqmodel/models/_const.py
+++ b/gptqmodel/models/_const.py
@@ -160,6 +160,7 @@ SUPPORTED_MODELS = [
     "minicpm3",
     "qwen2_moe",
     "qwen2_vl",
+    "qwen2_5_vl",
     "dbrx_converted",
     "deepseek_v2",
     "deepseek_v3",
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 1a5bd65c..6dd7f949 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -100,6 +100,7 @@ from .definitions.qwen import QwenGPTQ  # noqa: E402
 from .definitions.qwen2 import Qwen2GPTQ  # noqa: E402
 from .definitions.qwen2_moe import Qwen2MoeGPTQ  # noqa: E402
 from .definitions.qwen2_vl import Qwen2VLGPTQ  # noqa: E402
+from .definitions.qwen2_5_vl import Qwen2_5_VLGPTQ  # noqa: E402
 from .definitions.rw import RWGPTQ  # noqa: E402
 from .definitions.stablelmepoch import StableLMEpochGPTQ  # noqa: E402
 from .definitions.starcoder2 import Starcoder2GPTQ  # noqa: E402
@@ -153,6 +154,7 @@ MODEL_MAP = {
     "minicpm3": MiniCPM3GPTQ,
     "qwen2_moe": Qwen2MoeGPTQ,
     "qwen2_vl": Qwen2VLGPTQ,
+    "qwen2_5_vl": Qwen2_5_VLGPTQ,
     "dbrx": DbrxGPTQ,
     "dbrx_converted": DbrxConvertedGPTQ,
     "deepseek_v2": DeepSeekV2GPTQ,
diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py
index 922e15b4..1ce3bc9d 100644
--- a/gptqmodel/models/definitions/__init__.py
+++ b/gptqmodel/models/definitions/__init__.py
@@ -56,6 +56,7 @@ from .qwen import QwenGPTQ
 from .qwen2 import Qwen2GPTQ
 from .qwen2_moe import Qwen2MoeGPTQ
 from .qwen2_vl import Qwen2VLGPTQ
+from .qwen2_5_vl import Qwen2_5_VLGPTQ
 from .rw import RWGPTQ
 from .stablelmepoch import StableLMEpochGPTQ
 from .starcoder2 import Starcoder2GPTQ

The following change applies only to the GPTQModel v1.9.0 source code:

diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py
index 63ad09b0..e8daa43b
--- a/tests/models/ovis/image_to_test_dataset.py
+++ b/tests/models/ovis/image_to_test_dataset.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from gptqmodel.models import OvisGPTQ, Qwen2VLGPTQ
+from gptqmodel.models import OvisGPTQ, Qwen2VLGPTQ, Qwen2_5_VLGPTQ
 
 
 def format_ovis_dataset(image, assistant):
@@ -65,4 +65,7 @@ def get_calib_dataset(model):
     if isinstance(model, Qwen2VLGPTQ):
         return prepare_dataset(format_qwen2_vl_dataset, n_sample=1)
 
+    if isinstance(model, Qwen2_5_VLGPTQ):
+        return prepare_dataset(format_qwen2_vl_dataset, n_sample=1)
+
     raise NotImplementedError(f"Unsupported MODEL: {model.__class__}")

Copy qwen2_5_vl.py (its full contents are given at the end of this page) into the XXXXX/python3.10/site-packages/gptqmodel/models/definitions directory.
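
After patching and copying the file, the registration can be verified with a quick import check. This is only a sanity-check sketch based on the diff above, not part of the gist:

# Verify that the patched GPTQModel recognizes qwen2_5_vl (sanity check, not from the gist).
from gptqmodel.models import Qwen2_5_VLGPTQ
from gptqmodel.models._const import SUPPORTED_MODELS

assert "qwen2_5_vl" in SUPPORTED_MODELS
print(Qwen2_5_VLGPTQ)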

Quantization

  • Calibration data

coco2017
├── val2017
└── val.csv

The contents of val.csv look like this (a hypothetical sketch for producing such a file follows the sample):

...
val2017/000000037777.jpg,"a kitchen with wooden cabinets on the walls, a stove, multiple drawers, a refrigerator, a counter with fruits, and a well-organized layout for cooking and storage needs."
val2017/000000087038.jpg, "multiple people wearing sweatshirts, a person on a bicycle performing tricks, and another person mid-jump off a skateboarding ramp. The backdrop consists of buildings with graffiti artworks, adding a vibrant feel to the urban setting. The image appears to be set in an urban skate park or a designated area for extreme sports within a city." 
...
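
The gist does not say how val.csv was produced, and the sample captions are far more detailed than the original COCO ones, so they were presumably generated separately. Purely to illustrate the expected two-column format (relative image path, quoted caption), here is a hypothetical sketch that builds such a file from the standard captions_val2017.json annotations; the annotation path and the one-caption-per-image choice are assumptions:

# Hypothetical helper: build a val.csv (image_path,caption) file from COCO caption annotations.
# The annotation path and taking one caption per image are assumptions, not from the gist.
import csv
import json

with open("coco2017/annotations/captions_val2017.json", "r") as f:
    coco = json.load(f)

id_to_name = {img["id"]: img["file_name"] for img in coco["images"]}
first_caption = {}
for ann in coco["annotations"]:
    first_caption.setdefault(ann["image_id"], ann["caption"].strip())

with open("coco2017/val.csv", "w", newline="") as f:
    writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)  # quotes fields containing commas (the captions)
    for image_id, caption in first_caption.items():
        writer.writerow(["val2017/" + id_to_name[image_id], caption])
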
  • Run quantization
python quant_qwenvl_gptqmodel.py Qwen/Qwen2.5-VL-3B-Instruct 8 dataset/coco2017/val.csv
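
With bits set to 8 as in the command above, the quantized model is written next to the input model as Qwen/Qwen2.5-VL-3B-Instruct-GPTQModel-jnulzl-int8 (quant_qwenvl_gptqmodel.py builds the output directory name by appending -GPTQModel-jnulzl-int<bits> to model_path).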

Demo with the quantized model

python demo_qwen_vl.py Qwen/Qwen2.5-VL-3B-Instruct-GPTQModel-jnulzl-int4
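
demo_qwen_vl.py selects Qwen2_5_VLForConditionalGeneration because the directory name contains "Qwen2.5-VL", so the same demo script works for both the original and the quantized checkpoint.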

qwen2_5_vl.py:

# Copyright 2024-2025 ModelCloud.ai
# Copyright 2024-2025 [email protected]
# Contact: [email protected], x.com/qubitium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path
import shutil
from typing import Dict, Optional

from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer

from ...utils.calibration import batched
from ...utils.image import extract_vision_info, fetch_image
from ...utils.model import MODALITY, move_to
from .._const import CPU
from ..base import BaseGPTQModel


class Qwen2_5_VLGPTQ(BaseGPTQModel):
    loader = AutoModelForVision2Seq

    base_modules = ["model.embed_tokens", "model.norm"]
    pre_lm_head_norm_module = "model.norm"

    layers_node = "model.layers"
    layer_type = "Qwen2_5_VLDecoderLayer"
    layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]

    modality = [MODALITY.TEXT, MODALITY.IMAGE_TO_TEXT]

    require_load_processor = True

    quant_override_files = {
        "preprocessor_config.json": {
            "do_convert_rgb": True,
            "do_normalize": True,
            "do_rescale": True,
            "do_resize": True,
            "image_mean": [
                0.48145466,
                0.4578275,
                0.40821073
            ],
            "image_processor_type": "Qwen2VLImageProcessor",
            "image_std": [
                0.26862954,
                0.26130258,
                0.27577711
            ],
            "max_pixels": 12845056,
            "merge_size": 2,
            "min_pixels": 3136,
            "patch_size": 14,
            "processor_class": "Qwen2_5_VLProcessor",
            "resample": 3,
            "rescale_factor": 0.00392156862745098,
            "size": {
                "max_pixels": 1003520,
                "min_pixels": 3136
            },
            "temporal_patch_size": 2,
            "vision_token_id": 151654
        }
    }

    def pre_quantize_generate_hook_start(self):
        self.model.visual = move_to(self.model.visual, device=self.quantize_config.device)

    def pre_quantize_generate_hook_end(self):
        self.model.visual = move_to(self.model.visual, device=CPU)

    @staticmethod
    def process_vision_info(
        conversations: list[dict] | list[list[dict]],
    ) -> Optional[list[Image.Image]]:
        vision_infos = extract_vision_info(conversations)
        # Read images
        image_inputs = []
        for vision_info in vision_infos:
            if "image" in vision_info or "image_url" in vision_info:
                image_inputs.append(fetch_image(vision_info))
            else:
                raise ValueError("image, image_url should in content.")
        if len(image_inputs) == 0:
            image_inputs = None
        return image_inputs

    def preprocess_dataset(self, sample: Dict) -> Dict:
        return sample

    def prepare_dataset(
        self,
        calibration_dataset,
        calibration_dataset_concat_size,
        batch_size: int = 1,
        tokenizer=None,
    ):
        import json
        import tempfile

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_local_path)
        with tempfile.TemporaryDirectory() as tmp_dir:
            chat_template_file = os.path.join(self.model_local_path, "chat_template.json")
            if os.path.exists(chat_template_file):
                shutil.copyfile(chat_template_file, os.path.join(tmp_dir, "chat_template.json"))
            tokenizer.save_pretrained(tmp_dir)
            with open(os.path.join(tmp_dir, "preprocessor_config.json"), "w") as f:
                f.write(json.dumps(self.quant_override_files["preprocessor_config.json"]))
            processor = AutoProcessor.from_pretrained(tmp_dir)
            calib_data = []
            for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset):
                text = processor.apply_chat_template(
                    batch, tokenize=False, add_generation_prompt=True
                )
                image_inputs = self.process_vision_info(batch)
                inputs = processor(
                    text=text,
                    images=image_inputs,
                    videos=None,
                    padding=True,
                    return_tensors="pt",
                )
                calib_data.append(inputs)
            del processor
        return calib_data
jnulzl commented on Feb 27, 2025:

qwen2_5_vl.py was derived from GPTQModel's qwen2_vl.py with only slight modifications.
