Skip to content

Instantly share code, notes, and snippets.

@Narsil
Created December 1, 2022 10:09
Show Gist options
  • Save Narsil/a27a3062fd33a8139872463c3566db2b to your computer and use it in GitHub Desktop.
import sys
import os
import torch
from safetensors.torch import load_file
import datetime
from omegaconf import OmegaConf
sys.path.append(os.path.abspath(os.path.join(os.path.dirname( __file__ ), "repositories/stable-diffusion-stability-ai")))
from ldm.modules.diffusionmodules.model import Model
from ldm.util import instantiate_from_config
# The fast-GPU loading path in safetensors is still behind an env flag:
# it hasn't been fully verified upstream yet, but it has been exercised
# on many different environments.
os.environ["SAFETENSORS_FAST_GPU"] = "1"

# Checkpoint locations for the two formats under comparison.
pt_filename = "models/Stable-diffusion/sd14.ckpt"
st_filename = "models/Stable-diffusion/sd14.safetensors"

config = OmegaConf.load("v1-inference.yaml")

# Touch CUDA once so context/driver startup cost stays out of the timings.
torch.zeros((2, 2)).cuda()

# One shared model instance, fp16 on the first GPU; both timed passes
# below load their weights into this same model.
device = torch.device("cuda:0")
model_pt = instantiate_from_config(config.model)
model_pt.half().to(device)
# --- Timed pass 1: safetensors -----------------------------------------------
# NOTE(review): a separate model for this pass was deliberately disabled in
# the original; model_pt is reused, so the st0->st1 interval is ~0.
now = datetime.datetime.now  # shorthand for timestamping

start_st = now()
time_st0 = now()
time_st1 = now()

# Load tensors directly onto the GPU (SAFETENSORS_FAST_GPU is set earlier
# in this script).
weights = load_file(st_filename, device="cuda:0")
# Unwrap a "state_dict" envelope if present, then drop any stray nested copy.
weights = weights.pop("state_dict", weights)
weights.pop("state_dict", None)
time_st2 = now()

# Copy the loaded tensors into the already-fp16, already-on-GPU model.
model_pt.load_state_dict(weights, strict=False)
time_st3 = now()

load_time_st = now() - start_st
print(f"Loaded safetensors {load_time_st}")
# --- Timed pass 2: pytorch checkpoint ----------------------------------------
now = datetime.datetime.now  # shorthand for timestamping

start_pt = now()
time_pt0 = now()
time_pt1 = now()

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code — only run this on checkpoints you trust.
weights = torch.load(pt_filename, map_location="cuda:0")
# .ckpt files typically nest the weights under a "state_dict" key; unwrap
# it (mirrors the safetensors pass above), then drop any stray nested copy.
weights = weights.pop("state_dict", weights)
weights.pop("state_dict", None)
time_pt2 = now()

model_pt.load_state_dict(weights, strict=False)
time_pt3 = now()

load_time_pt = now() - start_pt
print(f"Loaded pytorch {load_time_pt}")
# Summary: overall wall-clock ratio plus a per-stage breakdown.
# NOTE(review): timings use wall-clock datetime and never call
# torch.cuda.synchronize(), so GPU work may still be in flight when a
# timestamp is taken — treat the numbers as approximate.
print(f"on GPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")
print(f"overall pt: {load_time_pt}")
print(f"overall st: {load_time_st}")
# Both "instantiate" intervals are ~0: the model is built once before the
# timed passes, and the per-pass instantiation was left commented out.
print(f"instantiate_from_config pt: {time_pt1-time_pt0}")
print(f"instantiate_from_config st: {time_st1-time_st0}")
print(f"load pt: {time_pt2-time_pt1}")
print(f"load st: {time_st2-time_st1}")
print(f"load_state_dict pt: {time_pt3-time_pt2}")
print(f"load_state_dict st: {time_st3-time_st2}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment