GGUF-Convert-Chunks
import os
import torch
import gguf
import numpy as np
import argparse
from tqdm import tqdm
from safetensors.torch import load_file, save_file

CHUNK_SIZE = 100                # Number of tensors to process in each chunk
QUANTIZATION_THRESHOLD = 1024   # Tensors with this many parameters or fewer stay in F32
REARRANGE_THRESHOLD = 512       # Minimum parameter count before a tensor is considered for reshaping
MAX_TENSOR_NAME_LENGTH = 127    # Maximum tensor name length (not enforced in this script)

# Architecture detection table: (arch_name, tuple of key sets). A state dict
# matches an architecture if it contains every key in at least one of the sets.
MODEL_DETECTION = (
    ("flux", (
        ("transformer_blocks.0.attn.norm_added_k.weight",),
        ("double_blocks.0.img_attn.proj.weight",),
    )),
    ("sd3", (
        ("transformer_blocks.0.attn.add_q_proj.weight",),
    )),
    ("sdxl", (
        ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
            "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
        ),  # Non-diffusers
        ("label_emb.0.0.weight",),
    )),
    ("sd1", (
        ("down_blocks.0.downsamplers.0.conv.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
            "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight",
        ),  # Non-diffusers
    )),
)

def parse_args():
    parser = argparse.ArgumentParser(description="Generate F16 GGUF files from a single UNET in chunks")
    parser.add_argument("--src", required=True, help="Source model ckpt/safetensors file.")
    parser.add_argument("--dst", help="Output UNET gguf file.")
    return parser.parse_args()

def load_state_dict(path):
    if any(path.endswith(x) for x in (".ckpt", ".pt", ".bin", ".pth")):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        state_dict = state_dict.get("model", state_dict)
    else:
        state_dict = load_file(path)

    # Only keep UNET keys, stripped of the "model.diffusion_model." prefix.
    sd = {}
    has_prefix = any("model.diffusion_model." in x for x in state_dict.keys())
    for k, v in state_dict.items():
        if has_prefix and "model.diffusion_model." not in k:
            continue
        if has_prefix:
            k = k.replace("model.diffusion_model.", "")
        sd[k] = v
    return sd

def detect_arch(state_dict):
    for arch, match_lists in MODEL_DETECTION:
        for match_list in match_lists:
            if all(key in state_dict for key in match_list):
                return arch
    raise ValueError("Unknown model architecture!")

def load_model(path):
    state_dict = load_state_dict(path)
    arch = detect_arch(state_dict)
    print(f"* Architecture detected from input: {arch}")
    # Flux checkpoints in Diffusers layout are not supported.
    if arch == "flux" and "transformer_blocks.0.attn.norm_added_k.weight" in state_dict:
        raise ValueError("The Diffusers UNET cannot be used for this!")
    return arch, state_dict

def handle_tensor(key, data):
    old_dtype = data.dtype

    # Convert to a numpy-compatible dtype; bf16 and fp8 have no numpy equivalent.
    if data.dtype == torch.bfloat16:
        data = data.to(torch.float32).numpy()
    elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
        data = data.to(torch.float16).numpy()
    else:
        data = data.numpy()

    n_dims = len(data.shape)
    data_shape = data.shape
    data_qtype = getattr(
        gguf.GGMLQuantizationType,
        "BF16" if old_dtype == torch.bfloat16 else "F16"
    )

    n_params = 1
    for dim_size in data_shape:
        n_params *= dim_size

    # .weight tensors whose names contain any of these prefixes are always kept in F32.
    blacklist = {
        "time_embedding.",
        "add_embedding.",
        "time_in.",
        "txt_in.",
        "vector_in.",
        "img_in.",
        "guidance_in.",
        "final_layer.",
    }

    if old_dtype in (torch.float32, torch.bfloat16):
        if n_dims == 1 or n_params <= QUANTIZATION_THRESHOLD or (".weight" in key and any(x in key for x in blacklist)):
            data_qtype = gguf.GGMLQuantizationType.F32

    # Reshape tensors whose last dimension is not a multiple of 256, recording the
    # original shape so it can be written as metadata and restored by the loader.
    orig_shape = None
    if (n_dims > 1 and n_params >= REARRANGE_THRESHOLD and (n_params / 256).is_integer()
            and not (data.shape[-1] / 256).is_integer()):
        orig_shape = data.shape
        data = data.reshape(n_params // 256, 256)

    try:
        data = gguf.quants.quantize(data, data_qtype)
    except (AttributeError, gguf.QuantError) as e:
        print(f"falling back to F16 for {key}: {e}")
        data_qtype = gguf.GGMLQuantizationType.F16
        data = gguf.quants.quantize(data, data_qtype)

    return key, data, data_qtype, orig_shape

def process_and_merge_chunks(state_dict, output_path, arch):
    writer = gguf.GGUFWriter(output_path, arch)
    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    writer.add_file_type(gguf.LlamaFileType.MOSTLY_F16)  # Assuming F16 as default

    # Split the state dict into fixed-size chunks so progress can be reported per chunk.
    chunks = [dict(list(state_dict.items())[i:i + CHUNK_SIZE]) for i in range(0, len(state_dict), CHUNK_SIZE)]
    for i, chunk in enumerate(chunks):
        for key, tensor in tqdm(chunk.items(), desc=f"Processing and merging chunk {i + 1}/{len(chunks)}"):
            new_key, quantized_data, data_qtype, orig_shape = handle_tensor(key, tensor)
            if orig_shape is not None:
                # Record the pre-reshape shape so the loader can restore it.
                writer.add_array(f"comfy.gguf.orig_shape.{new_key}", tuple(int(dim) for dim in orig_shape))
            # Pass the GGML type explicitly so the tensor is tagged with the dtype chosen above.
            writer.add_tensor(new_key, quantized_data, raw_dtype=data_qtype)

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()

def main():
    args = parse_args()
    arch, state_dict = load_model(args.src)
    out_path = args.dst or f"{os.path.splitext(args.src)[0]}-F16.gguf"
    process_and_merge_chunks(state_dict, out_path, arch)
    print(f"Quantized model saved to {out_path}")


if __name__ == "__main__":
    main()
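
Usage sketch (the script filename and model path below are assumptions; substitute wherever you saved this gist and your own checkpoint):

    python gguf_convert_chunks.py --src flux1-dev.safetensors --dst flux1-dev-F16.gguf

If --dst is omitted, the output is written next to the source file with an -F16.gguf suffix.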