Skip to content

Instantly share code, notes, and snippets.

@al-swaiti
Created September 11, 2024 16:29
Show Gist options
  • Save al-swaiti/20a38612cddce93c9b48ea05688ca6af to your computer and use it in GitHub Desktop.
Save al-swaiti/20a38612cddce93c9b48ea05688ca6af to your computer and use it in GitHub Desktop.
GGUF-Convert-Chunks
import os
import torch
import gguf
import numpy as np
import argparse
from tqdm import tqdm
from safetensors.torch import load_file, save_file
# Tensors per progress-bar chunk in process_and_merge_chunks.
CHUNK_SIZE = 100

# For F32/BF16 sources, tensors with at most this many elements stay in F32.
QUANTIZATION_THRESHOLD = 1024

# Minimum element count before a tensor may be flattened into rows of 256.
REARRANGE_THRESHOLD = 512

# Presumably the longest tensor name the target GGUF loader accepts — confirm.
MAX_TENSOR_NAME_LENGTH = 127

# Ordered (arch_name, match_lists) pairs used by detect_arch.
# An architecture matches when every key of ANY ONE tuple in its match_lists
# is present in the state dict; architectures are tried top to bottom and the
# first match wins (so the more specific "sdxl" must precede "sd1").
MODEL_DETECTION = (
    ("flux", (
        ("transformer_blocks.0.attn.norm_added_k.weight",),
        ("double_blocks.0.img_attn.proj.weight",),
    )),
    ("sd3", (
        ("transformer_blocks.0.attn.add_q_proj.weight",),
    )),
    ("sdxl", (
        (
            "down_blocks.0.downsamplers.0.conv.weight",
            "add_embedding.linear_1.weight",
        ),
        (  # Non-diffusers
            "input_blocks.3.0.op.weight",
            "input_blocks.6.0.op.weight",
            "output_blocks.2.2.conv.weight",
            "output_blocks.5.2.conv.weight",
        ),
        ("label_emb.0.0.weight",),
    )),
    ("sd1", (
        ("down_blocks.0.downsamplers.0.conv.weight",),
        (  # Non-diffusers
            "input_blocks.3.0.op.weight",
            "input_blocks.6.0.op.weight",
            "input_blocks.9.0.op.weight",
            "output_blocks.2.1.conv.weight",
            "output_blocks.5.2.conv.weight",
            "output_blocks.8.2.conv.weight",
        ),
    )),
)
def parse_args(argv=None):
    """Parse the command-line options for the converter.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        argparse.Namespace with ``src`` (required) and ``dst`` (``None`` when
        ``--dst`` is omitted).
    """
    parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET in chunks")
    parser.add_argument(
        "--src",
        required=True,
        help="Source model file (.safetensors, or .ckpt/.pt/.bin/.pth).",
    )
    parser.add_argument(
        "--dst",
        help="Output UNET GGUF file (default: <src without extension>-F16.gguf).",
    )
    return parser.parse_args(argv)
def load_state_dict(path):
    """Load a checkpoint and return its UNET tensors with the UNET prefix removed.

    Files ending in ``.ckpt``/``.pt``/``.bin``/``.pth`` (case-insensitive) are
    read with ``torch.load``; anything else is read with safetensors'
    ``load_file``. For torch checkpoints, a top-level ``"model"`` or
    ``"state_dict"`` entry holding a dict is unwrapped first.

    If any key contains ``"model.diffusion_model."``, only those keys are
    kept and that prefix is removed from them. Otherwise every key is
    returned unchanged.

    Args:
        path: Path to the checkpoint file.

    Returns:
        A new dict mapping tensor names to tensors.
    """
    prefix = "model.diffusion_model."

    if path.lower().endswith((".ckpt", ".pt", ".bin", ".pth")):
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is the safer choice for downloaded checkpoints.
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        # Weights are often nested one level down, e.g. Lightning .ckpt files
        # are {"state_dict": {...}, "global_step": ...}. Only unwrap real
        # dicts so a tensor that happens to be named "model" is left alone.
        for wrapper in ("model", "state_dict"):
            inner = state_dict.get(wrapper)
            if isinstance(inner, dict):
                state_dict = inner
                break
    else:
        state_dict = load_file(path)

    # Only keep the UNET, and drop its prefix.
    if not any(prefix in k for k in state_dict):
        return dict(state_dict)
    return {k.replace(prefix, ""): v for k, v in state_dict.items() if prefix in k}
def detect_arch(state_dict):
    """Return the first MODEL_DETECTION architecture name that fits *state_dict*.

    An architecture fits when, for at least one of its key tuples, every key
    of that tuple is present in *state_dict*. Entries are checked in order.

    Raises:
        ValueError: if no architecture fits.
    """
    for arch_name, key_sets in MODEL_DETECTION:
        fits = any(
            all(required in state_dict for required in key_set)
            for key_set in key_sets
        )
        if fits:
            return arch_name
    raise ValueError("Unknown model architecture!")
def load_model(path):
    """Load *path*, report its detected architecture, and return both.

    Returns:
        ``(arch, state_dict)`` where *arch* is the name from detect_arch and
        *state_dict* is the prefix-stripped dict from load_state_dict.

    Raises:
        ValueError: if the architecture is unknown, or if it is a flux model
            that uses diffusers-style key names (not supported here).
    """
    state_dict = load_state_dict(path)
    arch = detect_arch(state_dict)
    print(f"* Architecture detected from input: {arch}")

    uses_diffusers_keys = "transformer_blocks.0.attn.norm_added_k.weight" in state_dict
    if arch == "flux" and uses_diffusers_keys:
        raise ValueError("The Diffusers UNET can not be used for this!")

    return arch, state_dict
def handle_tensor(key, data):
    """Convert one state-dict tensor into GGUF-ready numpy data.

    Output type selection:
      * BF16 sources default to BF16; every other source defaults to F16.
      * For F32/BF16 sources, 1-D tensors, tensors with at most
        QUANTIZATION_THRESHOLD elements, and ``.weight`` tensors whose key
        contains a blacklisted fragment are kept in F32 instead.
      * If encoding fails with AttributeError or gguf.QuantError, the tensor
        is encoded as F16 instead.

    Args:
        key: Tensor name.
        data: torch tensor from the state dict.

    Returns:
        ``(key, data, data_qtype, orig_shape)``:
          * ``key`` is returned unchanged.
          * ``data`` is the result of ``gguf.quants.quantize``.
          * ``data_qtype`` is the ``gguf.GGMLQuantizationType`` it was encoded as.
          * ``orig_shape`` is the tensor's shape before it was flattened into
            rows of 256 elements, or ``None`` when no flattening happened.
    """
    old_dtype = data.dtype

    # numpy has no bfloat16/float8 dtypes, so widen those before .numpy().
    # getattr(..., "_invalid") keeps this working on torch builds without float8.
    float8_dtypes = (
        getattr(torch, "float8_e4m3fn", "_invalid"),
        getattr(torch, "float8_e5m2", "_invalid"),
    )
    if old_dtype == torch.bfloat16:
        data = data.to(torch.float32).numpy()
    elif old_dtype in float8_dtypes:
        data = data.to(torch.float16).numpy()
    else:
        data = data.numpy()

    n_dims = data.ndim
    n_params = int(data.size)  # 1 for a 0-d tensor, same as a product over no dims

    data_qtype = (
        gguf.GGMLQuantizationType.BF16
        if old_dtype == torch.bfloat16
        else gguf.GGMLQuantizationType.F16
    )

    # Key fragments whose ".weight" tensors stay F32 for F32/BF16 sources.
    blacklist = {
        "time_embedding.",
        "add_embedding.",
        "time_in.",
        "txt_in.",
        "vector_in.",
        "img_in.",
        "guidance_in.",
        "final_layer.",
    }

    if old_dtype in (torch.float32, torch.bfloat16):
        keep_f32 = (
            n_dims == 1
            or n_params <= QUANTIZATION_THRESHOLD
            or (".weight" in key and any(x in key for x in blacklist))
        )
        if keep_f32:
            data_qtype = gguf.GGMLQuantizationType.F32

    # Flatten large multi-dim tensors whose last dim is not a multiple of 256
    # (but whose total size is) into rows of 256 elements. The caller records
    # orig_shape so the shape can be restored. Presumably this is so that
    # 256-element block quantization can be applied later — confirm.
    orig_shape = None
    if (
        n_dims > 1
        and n_params >= REARRANGE_THRESHOLD
        and n_params % 256 == 0
        and data.shape[-1] % 256 != 0
    ):
        orig_shape = data.shape
        data = data.reshape(n_params // 256, 256)

    try:
        data = gguf.quants.quantize(data, data_qtype)
    except (AttributeError, gguf.QuantError) as e:
        # tqdm.write avoids breaking the caller's active progress bar.
        tqdm.write(f"falling back to F16 for {key}: {e}")
        data_qtype = gguf.GGMLQuantizationType.F16
        data = gguf.quants.quantize(data, data_qtype)

    return key, data, data_qtype, orig_shape
def process_and_merge_chunks(state_dict, output_path, arch):
    """Convert every tensor in *state_dict* and write them to one GGUF file.

    Tensors are passed through handle_tensor in groups of CHUNK_SIZE, each
    group with its own progress bar. The GGUF writer presumably buffers every
    added tensor until write_tensors_to_file, so chunking does not lower peak
    memory — confirm against the installed gguf version.

    Args:
        state_dict: Mapping of tensor name to torch tensor.
        output_path: Destination ``.gguf`` path.
        arch: Architecture name stored in the GGUF header.

    Raises:
        ValueError: if any tensor name is longer than MAX_TENSOR_NAME_LENGTH.
    """
    # Fail before doing any work rather than emit a file a loader will reject.
    too_long = [k for k in state_dict if len(k) > MAX_TENSOR_NAME_LENGTH]
    if too_long:
        raise ValueError(
            f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. "
            f"Tensors exceeding the limit: {', '.join(repr(k) for k in too_long)}"
        )

    writer = gguf.GGUFWriter(output_path, arch)
    try:
        writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
        # handle_tensor keeps BF16 sources in BF16, so label the file accordingly.
        first = next(iter(state_dict.values()), None)
        if first is not None and first.dtype == torch.bfloat16:
            writer.add_file_type(gguf.LlamaFileType.MOSTLY_BF16)
        else:
            writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)

        # Materialise the items once; slicing a fresh list per chunk was quadratic.
        items = list(state_dict.items())
        n_chunks = (len(items) + CHUNK_SIZE - 1) // CHUNK_SIZE
        for chunk_idx, start in enumerate(range(0, len(items), CHUNK_SIZE)):
            chunk = items[start:start + CHUNK_SIZE]
            desc = f"Processing and merging chunk {chunk_idx + 1}/{n_chunks}"
            for key, tensor in tqdm(chunk, desc=desc):
                new_key, quantized_data, data_qtype, orig_shape = handle_tensor(key, tensor)
                if orig_shape is not None:
                    writer.add_array(f"comfy.gguf.orig_shape.{new_key}", tuple(int(dim) for dim in orig_shape))
                # raw_dtype is required: quantized/BF16 output is a uint8 byte
                # array, which add_tensor cannot map to a GGML type by itself.
                writer.add_tensor(new_key, quantized_data, raw_dtype=data_qtype)

        writer.write_header_to_file()
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file(progress=True)
    finally:
        writer.close()
def main():
    """Command-line entry point: convert ``--src`` into a GGUF file.

    The output goes to ``--dst`` when it is given and non-empty. Otherwise it
    goes to the source path with its extension replaced by ``-F16.gguf``.
    """
    args = parse_args()
    arch, state_dict = load_model(args.src)

    if args.dst:
        out_path = args.dst
    else:
        src_stem = os.path.splitext(args.src)[0]
        out_path = f"{src_stem}-F16.gguf"

    process_and_merge_chunks(state_dict, out_path, arch)
    print(f"Quantized model saved to {out_path}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment