Skip to content

Instantly share code, notes, and snippets.

@twobob
Created October 5, 2022 17:58
Show Gist options
  • Save twobob/f36342b286006c3830409387d88cc482 to your computer and use it in GitHub Desktop.
Save twobob/f36342b286006c3830409387d88cc482 to your computer and use it in GitHub Desktop.
Checkpoint pruning code for Dance Diffusion. All props to Waifu.
import os

import torch
def _to_half(state_dict):
    """Return a copy of *state_dict* with floating-point tensors cast to fp16.

    Integer/bool tensors and non-tensor values are passed through unchanged,
    because casting them to half precision would corrupt (or fail on) them.
    """
    return {
        key: value.half()
        if torch.is_tensor(value) and value.is_floating_point()
        else value
        for key, value in state_dict.items()
    }


def prune_it(p):
    """Write a smaller copy of the checkpoint at *p* next to the original.

    The copy omits the top-level ``optimizer_states`` entry and casts the
    floating-point tensors in ``state_dict`` (if present) to fp16. All other
    top-level entries are kept. The result is saved as ``<stem>-pruned.ckpt``
    in the same directory as *p*.

    Args:
        p: Path to a checkpoint file loadable with ``torch.load``.

    Returns:
        The path of the pruned checkpoint that was written.
    """
    print(f"prunin' in path: {p}")
    size_initial = os.path.getsize(p)
    # NOTE: torch.load unpickles the file -- only run this on trusted checkpoints.
    sd = torch.load(p, map_location="cpu")
    print(sd.keys())
    if "optimizer_states" in sd:
        print(f"removing optimizer states for path {p}")
    nsd = {k: v for k, v in sd.items() if k != "optimizer_states"}
    if "global_step" in sd:
        print(f"This is global step {sd['global_step']}.")
    # Checkpoints without a state_dict are still saved (minus optimizer states).
    if "state_dict" in nsd:
        nsd["state_dict"] = _to_half(nsd["state_dict"])
    fn = f"{os.path.splitext(p)[0]}-pruned.ckpt"
    print(f"saving pruned checkpoint at: {fn}")
    torch.save(nsd, fn)
    newsize = os.path.getsize(fn)
    MSG = (
        f"New ckpt size: {newsize*1e-9:.2f} GB. "
        f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer "
        f"states and casting weights to fp16"
    )
    print(MSG)
    return fn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment