Created
July 29, 2023 02:36
-
-
Save madebyollin/034afe6670fc03966d075912cbccf797 to your computer and use it in GitHub Desktop.
script for comparing the contents of safetensors files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from pathlib import Path | |
from safetensors.torch import load_file | |
def summarize_tensor(x):
    """Return a short "(min, mean, max)" summary string for a tensor.

    Args:
        x: A tensor, or None when the key is absent from a file.

    Returns:
        "None" if x is None, "(empty)" if x has no elements, otherwise the
        minimum, mean and maximum formatted to three decimal places.
    """
    if x is None:
        return "None"
    if x.numel() == 0:
        # min()/max() raise on zero-element tensors, so report them explicitly.
        return "(empty)"
    # Cast first: integer and bool tensors do not support mean().
    x = x.float()
    return f"({x.min().item():.3f}, {x.mean().item():.3f}, {x.max().item():.3f})"
def compare_keys(dev_keys, ref_keys, dev_name, ref_name):
    """Describe how two key collections differ.

    Args:
        dev_keys: Keys of the first collection; must support len() and set
            difference against ref_keys (e.g. a set or a dict keys view).
        ref_keys: Keys of the second collection, same requirements.
        dev_name: Label used for the first collection in the report.
        ref_name: Label used for the second collection in the report.

    Returns:
        A string of three lines, each preceded by a newline: the key counts,
        the keys only in ref, and the keys only in dev.
    """
    only_in_ref = ref_keys - dev_keys
    only_in_dev = dev_keys - ref_keys
    report_lines = [
        f"{ref_name} has {len(ref_keys)} keys; {dev_name} has {len(dev_keys)} keys",
        f"keys in {ref_name} but not in {dev_name}: {only_in_ref}",
        f"keys in {dev_name} but not in {ref_name}: {only_in_dev}",
    ]
    # Every line, including the first, is prefixed with a newline.
    return "".join(f"\n{line}" for line in report_lines)
def main(a_path, b_path):
    """Print a key-by-key comparison of two safetensors files.

    Loads both files, prints which keys are unique to each, then prints one
    line per key: "Identical" when both files hold the same shape and values,
    otherwise the (min, mean, max) summary of each side followed by the ratio
    of standard deviations (b / a). The ratio is colored red above 1.5x and
    green below 0.5x; it is left blank when the key exists in only one file.

    Args:
        a_path: Path to the first safetensors file.
        b_path: Path to the second safetensors file.

    Raises:
        AssertionError: If either path does not exist.
    """
    a_path, b_path = Path(a_path), Path(b_path)
    assert a_path.exists(), f"{a_path} does not exist"
    assert b_path.exists(), f"{b_path} does not exist"

    a_st = load_file(a_path)
    b_st = load_file(b_path)
    print(compare_keys(a_st.keys(), b_st.keys(), a_path, b_path))

    all_keys = sorted(a_st.keys() | b_st.keys())
    # default=0 keeps max() from raising when both files contain no keys.
    key_col_width = max((len(k) for k in all_keys), default=0) + 1

    for k in all_keys:
        a_val = a_st.get(k)
        b_val = b_st.get(k)
        both_present = a_val is not None and b_val is not None

        # Check shapes before the elementwise comparison: `==` raises on
        # incompatible shapes and broadcasts compatible ones, which could
        # report differently-shaped tensors as identical.
        if both_present and a_val.shape == b_val.shape and (a_val == b_val).all():
            print(f"{k.ljust(key_col_width)} \033[37mIdentical\033[0m")
            continue

        diff = f"\033[34m{summarize_tensor(a_val).ljust(32)} \033[30m->\033[0m \033[36m{summarize_tensor(b_val).ljust(32)}\033[0m"
        if both_present:
            # Cast to float before std(), as summarize_tensor does for its
            # statistics: std() is not defined for integer/bool tensors.
            net_change = (b_val.float().std() / a_val.float().std().add(1e-8)).item()
            net_change_str = f"{net_change:.4f}x"
        else:
            # Key missing on one side: neutral ratio, nothing to display.
            net_change = 1.0
            net_change_str = ""

        if net_change > 1.5:
            net_change_str = f"\033[31m{net_change_str}\033[0m"
        elif net_change < 0.5:
            net_change_str = f"\033[32m{net_change_str}\033[0m"
        else:
            net_change_str = f"\033[30m{net_change_str}\033[0m"
        print(f"{k.ljust(key_col_width)} {diff} {net_change_str}")
if __name__ == "__main__":
    import sys

    # Validate the argument count explicitly: an assert is stripped under
    # `python -O` and would show a traceback instead of a usage hint.
    if len(sys.argv[1:]) != 2:
        sys.exit(f"Try: {sys.argv[0]} a.safetensors b.safetensors")
    main(*sys.argv[1:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment