Created September 11, 2023 19:58
-
-
Save lucataco/338ed0efd2041ddf093f2bace84a6aee to your computer and use it in GitHub Desktop.
Replicate-LoRA-manual-load-weights
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
from safetensors import safe_open

from dataset_and_utils import TokenEmbeddingsHandler
# Load the SDXL 1.0 base pipeline in half precision (fp16 weights variant,
# safetensors format) and move it to the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")

# K_EULER scheduler: swap in EulerDiscreteScheduler, built from the config of
# the scheduler the pipeline shipped with so its settings carry over.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# Load every tensor from the LoRA safetensors file directly onto the GPU.
with safe_open("weights/lora.safetensors", framework="pt", device="cuda") as f:
    tensors = {key: f.get_tensor(key) for key in f.keys()}

# strict=False: the file is expected to cover only part of the UNet, so
# missing keys are tolerated. Unexpected keys are different: they are tensors
# from the file that matched no UNet parameter and were silently dropped.
# load_state_dict reports them in its return value, so check it instead of
# ignoring it.
incompatible = pipe.unet.load_state_dict(tensors, strict=False)  # should take < 2 seconds
if incompatible.unexpected_keys:
    # NOTE(review): if this fires for a LoRA file, the UNet probably needs
    # LoRA attention processors installed before loading so the keys line up;
    # confirm against the trainer's saved key format.
    print(
        f"Warning: {len(incompatible.unexpected_keys)} of {len(tensors)} "
        "tensors in weights/lora.safetensors matched no UNet parameter and "
        f"were ignored (e.g. {incompatible.unexpected_keys[0]!r})."
    )
# Pass both SDXL text encoders and their matching tokenizers to
# TokenEmbeddingsHandler, then load the trained embeddings file.
# NOTE(review): presumably this registers the <s0>/<s1> tokens used in the
# prompt; confirm against TokenEmbeddingsHandler.load_embeddings.
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
embhandler.load_embeddings("weights/embeddings.pti")
prompt = "A <s0><s1> emoji of a man"

# Fixed seed for reproducible output. Set seed = None to draw a random seed
# from os.urandom instead (2 bytes, i.e. a value in 0..65535).
seed = 57727
if seed is None:
    seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")
generator = torch.Generator("cuda").manual_seed(seed)

common_args = {
    "prompt": prompt,
    "guidance_scale": 7.5,
    "generator": generator,
    "num_inference_steps": 50,
}
# Run the pipeline and save the first generated image.
image = pipe(**common_args).images[0]
image.save("output.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.