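# test.py
# Benchmark: load a Stable Diffusion v1.4 checkpoint onto the GPU via
# safetensors load_file vs. torch.load, and report per-step timings.
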
import sys
import os
import datetime

import torch
from omegaconf import OmegaConf
from safetensors.torch import load_file, save_file

# Make the stable-diffusion repo importable so the `ldm` modules below resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "repositories/stable-diffusion-stability-ai")))
from ldm.modules.diffusionmodules.model import Model
from ldm.util import instantiate_from_config
# Opt-in flag: it is required because the fast GPU loading path hasn't been fully
# verified yet, but it has been tested on many different environments.
os.environ["SAFETENSORS_FAST_GPU"] = "1"

pt_filename = "models/Stable-diffusion/sd14_fp16.ckpt"
st_filename = "models/Stable-diffusion/sd14_fp16.safetensors"
config = OmegaConf.load("v1-inference.yaml")

# Keep CUDA context startup out of the measurement.
torch.zeros((2, 2)).cuda()
# !pip install accelerate
from accelerate import init_empty_weights
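# init_empty_weights() creates parameters on the "meta" device, so instantiation
# allocates no real weight memory. Doing one instantiation up front keeps one-time
# import/config costs out of both timed runs below.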
with init_empty_weights():
    model_pt = instantiate_from_config(config.model)
start_st = datetime.datetime.now()
time_st0 = datetime.datetime.now()
with init_empty_weights():
    model_st = instantiate_from_config(config.model)
time_st1 = datetime.datetime.now()
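# With SAFETENSORS_FAST_GPU=1, load_file(..., device="cuda:0") should materialize
# the tensors directly on the GPU without an intermediate CPU copy.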
weights = load_file(st_filename, device="cuda:0")
time_st2 = datetime.datetime.now()
# model_st.half().to(torch.device("cuda:0"))
model_st.load_state_dict(weights, strict=False)
time_st3 = datetime.datetime.now()
load_time_st = datetime.datetime.now() - start_st
print(f"Loaded safetensors {load_time_st}")
model_st = None
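# Same measurement through the classic torch.load (pickle) path.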
start_pt = datetime.datetime.now()
time_pt0 = datetime.datetime.now()
with init_empty_weights():
    model_pt = instantiate_from_config(config.model)
time_pt1 = datetime.datetime.now()
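# torch.load unpickles the full checkpoint dict; SD .ckpt files nest the weights
# under a "state_dict" key, which the two pops below unwrap and clean up.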
weights = torch.load(pt_filename, map_location="cuda:0")
weights = weights.pop("state_dict", weights)
weights.pop("state_dict", None)
time_pt2 = datetime.datetime.now()
# model_pt.half().to(torch.device("cuda:0"))
model_pt.load_state_dict(weights, strict=False)
time_pt3 = datetime.datetime.now()
load_time_pt = datetime.datetime.now() - start_pt
print(f"Loaded pytorch {load_time_pt}")
model_pt = None
print(f"on GPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X") | |
print(f"overall pt: {load_time_pt}") | |
print(f"overall st: {load_time_st}") | |
print(f"instantiate_from_config pt: {time_pt1-time_pt0}") | |
print(f"instantiate_from_config st: {time_st1-time_st0}") | |
print(f"load pt: {time_pt2-time_pt1}") | |
print(f"load st: {time_st2-time_st1}") | |
print(f"load_state_dict pt: {time_pt3-time_pt2}") | |
print(f"load_state_dict st: {time_st3-time_st2}") | |