-
-
Save kiyoon/ffb1af1e59be8e2802c1cd6ca3d9d1bc to your computer and use it in GitHub Desktop.
Script for comparing the contents of two safetensors files.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from pathlib import Path | |
from safetensors.torch import load_file | |
def summarize_tensor(x):
    """Return a compact ``(min, mean, max)`` summary string for a tensor.

    Args:
        x: A tensor exposing ``numel()``, ``float()``, ``min()``, ``mean()``
            and ``max()``, or ``None`` when there is no tensor to describe.

    Returns:
        ``"None"`` when ``x`` is ``None``, ``"(empty)"`` when the tensor has
        no elements, otherwise ``"(min, mean, max)"`` with each value
        formatted to three decimal places.
    """
    if x is None:
        return "None"
    if x.numel() == 0:
        # min()/max() raise on zero-element tensors, so report them explicitly
        # instead of letting one empty tensor abort the whole comparison.
        return "(empty)"
    # Cast so that mean() also works for integer and half-precision dtypes.
    x = x.float()
    return f"({x.min().item():.3f}, {x.mean().item():.3f}, {x.max().item():.3f})"
def compare_keys(dev_keys, ref_keys, dev_name, ref_name):
    """Describe how two set-like key collections differ.

    Args:
        dev_keys: Set-like collection of keys from the first source.
        ref_keys: Set-like collection of keys from the second source.
        dev_name: Label used for ``dev_keys`` in the report.
        ref_name: Label used for ``ref_keys`` in the report.

    Returns:
        A string starting with a newline that lists both key counts and the
        keys present on only one side.
    """
    missing_from_dev = ref_keys - dev_keys
    missing_from_ref = dev_keys - ref_keys
    parts = [
        f"\n{ref_name} has {len(ref_keys)} keys; {dev_name} has {len(dev_keys)} keys",
        f"\nkeys in {ref_name} but not in {dev_name}: {missing_from_dev}",
        f"\nkeys in {dev_name} but not in {ref_name}: {missing_from_ref}",
    ]
    return "".join(parts)
def main(a_path, b_path):
    """Print a key-by-key comparison of two safetensors files.

    First prints the key-set differences, then one line per key (sorted):
    either "Identical", or the ``(min, mean, max)`` summary of each side
    followed by the ratio of their standard deviations (b / a). The ratio is
    coloured red above 1.5x and green below 0.5x.

    Args:
        a_path: Path to the first safetensors file.
        b_path: Path to the second safetensors file.

    Raises:
        FileNotFoundError: If either path does not exist.
    """
    a_path, b_path = Path(a_path), Path(b_path)
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    for path in (a_path, b_path):
        if not path.exists():
            raise FileNotFoundError(f"No such file: {path}")

    a_st = load_file(a_path)
    b_st = load_file(b_path)
    print(compare_keys(a_st.keys(), b_st.keys(), a_path, b_path))

    all_keys = sorted(a_st.keys() | b_st.keys())
    # ``default=0`` avoids a ValueError when neither file contains any tensor.
    key_col_width = max((len(k) for k in all_keys), default=0) + 1

    for k in all_keys:
        a_val = a_st.get(k)
        b_val = b_st.get(k)
        both_present = a_val is not None and b_val is not None

        # Check shapes before comparing elements: ``==`` broadcasts, so tensors
        # of different shapes would either raise or be reported as identical.
        if (
            both_present
            and a_val.shape == b_val.shape
            and bool((a_val == b_val).all())
        ):
            print(f"{k.ljust(key_col_width)} \033[37mIdentical\033[0m")
            continue

        diff = f"\033[34m{summarize_tensor(a_val).ljust(32)} \033[30m->\033[0m \033[36m{summarize_tensor(b_val).ljust(32)}\033[0m"
        if both_present:
            # std() is not defined for integer dtypes, so cast to float first
            # (the same cast summarize_tensor applies before its reductions).
            ratio = b_val.float().std() / a_val.float().std().add(1e-8)
            net_change = ratio.item()
            net_change_str = f"{net_change:.4f}x"
        else:
            # Key exists on one side only: no ratio to show, use neutral colour.
            net_change = 1.0
            net_change_str = ""

        if net_change > 1.5:
            net_change_str = f"\033[31m{net_change_str}\033[0m"
        elif net_change < 0.5:
            net_change_str = f"\033[32m{net_change_str}\033[0m"
        else:
            net_change_str = f"\033[30m{net_change_str}\033[0m"
        print(f"{k.ljust(key_col_width)} {diff} {net_change_str}")
if __name__ == "__main__":
    import sys

    # Validate the argument count explicitly: an ``assert`` is stripped under
    # ``python -O``, which would leave a confusing TypeError from main().
    if len(sys.argv[1:]) != 2:
        sys.exit(f"Try: {sys.argv[0]} a.safetensors b.safetensors")
    main(*sys.argv[1:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment