Created
January 12, 2026 07:33
-
-
Save Extraltodeus/829ca804d355a37dca7bd134f5f80c9d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import safetensors | |
| import safetensors.torch | |
| import torch | |
| from tqdm import tqdm | |
| model_path = "./amazing_but_too_big_qwen3_4b_for_z_image.safetensors" | |
| save_path = "./super_mario_precision_result.safetensors" | |
| out_dtype = torch.float8_e4m3fn | |
| out_dtype_max = torch.finfo(out_dtype).max | |
| eps16 = torch.finfo(torch.float16).eps | |
def get_keys(path, device="cpu"):
    """Return the tensor names stored in the safetensors file at ``path``.

    The file is opened with the PyTorch framework on ``device``. If anything
    goes wrong while opening or reading it, the exception and the path are
    printed and ``None`` is returned instead of raising.
    """
    try:
        with safetensors.safe_open(path, framework="pt", device=device) as handle:
            names = handle.keys()
    except Exception as err:
        print(f"\n{err}\nPath: {path}")
        return None
    return names
def get_layer(path, key, device="cpu"):
    """Load a copy of the tensor named ``key`` from the safetensors file.

    Returns ``None`` when the file has no such key, or when opening/reading
    fails (in which case the exception and the path are printed).
    """
    try:
        with safetensors.safe_open(path, framework="pt", device=device) as handle:
            if key not in handle.keys():
                return None
            tensor = handle.get_tensor(key)
            # Hand back an independent copy of the loaded tensor.
            return tensor.to(copy=True)
    except Exception as err:
        print(f"\n{err}\nPath: {path}")
    return None
| def fp8_e4m3fn_quantize_layers(layer, layer_key): | |
| r_out = {} | |
| layer = layer.to(torch.float32) | |
| mxabs = layer.abs().max() | |
| if mxabs == 0 or not torch.isfinite(mxabs): | |
| s = torch.tensor(1.0, device=layer.device, dtype=torch.float32) | |
| q = torch.zeros_like(layer, dtype=out_dtype) | |
| else: | |
| s = (mxabs / (out_dtype_max)).to(torch.float32) | |
| q = layer.div(s + torch.finfo(layer.dtype).eps).clamp(-out_dtype_max, out_dtype_max).to(dtype=out_dtype) | |
| scale_key = ".".join(layer_key.split(".")[:-1]) + f".scale_{layer_key.split('.')[-1]}" | |
| input_key = ".".join(layer_key.split(".")[:-1]) + f".scale_input" | |
| r_out[layer_key] = q.contiguous().cpu() | |
| r_out[scale_key] = s.contiguous().cpu() | |
| r_out[input_key] = torch.tensor([eps16]) | |
| return r_out | |
# Convert the checkpoint: every ``*.weight`` tensor with more than 100000
# elements (except the token embedding) is quantized to scaled FP8; every
# other tensor is stored as bfloat16.
new_model = {}
p = model_path
# Token embedding: always loaded on the CPU and never quantized.
tkl = "model.embed_tokens.weight"

keys = get_keys(p)
if keys is None:
    # get_keys has already printed the underlying error.
    raise SystemExit(f"Could not read tensor names from {p}")
lk = len(keys)

# Use the first CUDA device when one is available; otherwise run on the CPU
# instead of failing on every tensor load.
compute_device = 0 if torch.cuda.is_available() else "cpu"

tot_quant = 0
# The context manager closes the progress bar even if a tensor fails to load.
with tqdm(total=lk, desc="total layers quantized: 0") as pb:
    for k in keys:
        tensor = get_layer(p, k, "cpu" if k == tkl else compute_device)
        if tensor is None:
            # get_layer has already printed the underlying error (if any).
            raise SystemExit(f"Could not load tensor {k!r} from {p}")
        if k != tkl and k.endswith(".weight") and tensor.numel() > 100000:
            new_model.update(fp8_e4m3fn_quantize_layers(tensor, k))
            tot_quant += 1
            pb.set_description(f"total layers quantized: {tot_quant}", refresh=False)
        else:
            # Move to the CPU right away so converted tensors do not pile up
            # on the compute device until the file is written.
            new_model[k] = tensor.to(torch.bfloat16).cpu()
        pb.update(1)

if tot_quant > 0:
    # Marker entry added only when at least one tensor was quantized.
    new_model['scaled_fp8'] = torch.tensor([1.])
print("Saving...")
safetensors.torch.save_file(new_model, save_path)
print("Done!", f"total layers quantized: {tot_quant}")
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
(I forgot to remove the quantized_layers dict, which is useless.)