Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save Extraltodeus/829ca804d355a37dca7bd134f5f80c9d to your computer and use it in GitHub Desktop.

Select an option

Save Extraltodeus/829ca804d355a37dca7bd134f5f80c9d to your computer and use it in GitHub Desktop.
import safetensors
import safetensors.torch
import torch
from tqdm import tqdm
model_path = "./amazing_but_too_big_qwen3_4b_for_z_image.safetensors"
save_path = "./super_mario_precision_result.safetensors"
out_dtype = torch.float8_e4m3fn
out_dtype_max = torch.finfo(out_dtype).max
eps16 = torch.finfo(torch.float16).eps
def get_keys(path, device="cpu"):
    """List the tensor names stored in a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file.
        device: Device argument forwarded to ``safetensors.safe_open``.

    Returns:
        A list of tensor names, or ``None`` when the file could not be
        opened or read (the error and the path are printed in that case).
    """
    try:
        with safetensors.safe_open(path, framework="pt", device=device) as f:
            # Materialize the names so the result does not depend on the
            # file handle, which is closed when the ``with`` block exits.
            return list(f.keys())
    except Exception as e:
        # Best-effort helper: report the failure and let the caller decide.
        print(f"\n{e}\nPath: {path}")
    return None
def get_layer(path, key, device="cpu"):
    """Load a single tensor from a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file.
        key: Name of the tensor to load.
        device: Device argument forwarded to ``safetensors.safe_open``.

    Returns:
        A copy of the requested tensor, or ``None`` when ``key`` is not in
        the file or the file could not be read (the error and the path are
        printed in the latter case).
    """
    try:
        with safetensors.safe_open(path, framework="pt", device=device) as f:
            if key in f.keys():
                # Copy so the returned tensor owns its storage once the
                # file handle is closed.
                return f.get_tensor(key).to(copy=True)
    except Exception as e:
        # Best-effort helper: report the failure and let the caller decide.
        print(f"\n{e}\nPath: {path}")
    # Reached when the key is absent or loading failed.
    return None
def fp8_e4m3fn_quantize_layers(layer, layer_key):
r_out = {}
layer = layer.to(torch.float32)
mxabs = layer.abs().max()
if mxabs == 0 or not torch.isfinite(mxabs):
s = torch.tensor(1.0, device=layer.device, dtype=torch.float32)
q = torch.zeros_like(layer, dtype=out_dtype)
else:
s = (mxabs / (out_dtype_max)).to(torch.float32)
q = layer.div(s + torch.finfo(layer.dtype).eps).clamp(-out_dtype_max, out_dtype_max).to(dtype=out_dtype)
scale_key = ".".join(layer_key.split(".")[:-1]) + f".scale_{layer_key.split('.')[-1]}"
input_key = ".".join(layer_key.split(".")[:-1]) + f".scale_input"
r_out[layer_key] = q.contiguous().cpu()
r_out[scale_key] = s.contiguous().cpu()
r_out[input_key] = torch.tensor([eps16])
return r_out
# Convert the checkpoint: large ``.weight`` tensors are quantized to FP8 with
# a per-tensor scale, every other tensor is stored as bfloat16.
new_model = {}
p = model_path
# The token embedding is never quantized and is always loaded on the CPU.
tkl = "model.embed_tokens.weight"
keys = get_keys(p)
if keys is None:
    # get_keys has already printed the underlying error.
    raise SystemExit(f"Could not read tensor names from {p}")
lk = len(keys)
tot_quant = 0
# Use the first CUDA device when there is one, otherwise stay on the CPU.
load_device = 0 if torch.cuda.is_available() else "cpu"
with tqdm(total=lk, desc="total layers quantized: 0") as pb:
    for k in keys:
        tensor = get_layer(p, k, load_device if k != tkl else "cpu")
        if tensor is None:
            # Fail loudly instead of writing an incomplete checkpoint.
            raise SystemExit(f"Could not load tensor {k!r} from {p}")
        if k != tkl and k.endswith(".weight") and tensor.numel() > 100000:
            new_model.update(fp8_e4m3fn_quantize_layers(tensor, k))
            tot_quant += 1
            pb.set_description(f"total layers quantized: {tot_quant}", refresh=False)
        else:
            # Move to the CPU right away so converted tensors do not pile up
            # on the load device until the file is saved.
            new_model[k] = tensor.to(torch.bfloat16).cpu()
        pb.update(1)
if tot_quant > 0:
    # Marker entry written only when at least one layer was quantized.
    new_model['scaled_fp8'] = torch.tensor([1.])
print("Saving...")
safetensors.torch.save_file(new_model, save_path)
print("Done!", f"total layers quantized: {tot_quant}")
@Extraltodeus
Copy link
Copy Markdown
Author

(I forgot to remove the quantized_layers dict, which is useless.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment