# Source: GitHub Gist cloneofsimo/68e2c2d9869b63a561811ed6ede4b790
# (GitHub page navigation text from the scraped gist page has been reduced to this header.)
import torch
from diffusers import DiffusionPipeline
from safetensors import safe_open
from dataset_and_utils import TokenEmbeddingsHandler
# Local directory containing the diffusers pipeline to load.
MODEL_CACHE = "./cache"

# Load the "fp16" weight variant as float16 tensors, then move every
# pipeline component onto the GPU.
pipe = DiffusionPipeline.from_pretrained(
    MODEL_CACHE, torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")
# Fine-tuned UNet weights, read straight onto the GPU.
UNET_CHECKPOINT = "ft_masked/unet/checkpoint-1000.unet.safetensors"
with safe_open(UNET_CHECKPOINT, framework="pt", device="cuda") as f:
    tensors = {key: f.get_tensor(key) for key in f.keys()}

# strict=False tolerates UNet parameters that are absent from the checkpoint
# (presumably it only stores the fine-tuned subset -- confirm). It would also
# silently ignore checkpoint keys that match nothing in the UNet, e.g. after a
# naming mismatch, which leaves the fine-tune unapplied. Surface those instead.
load_result = pipe.unet.load_state_dict(tensors, strict=False)  # should take < 2 seconds
if load_result.unexpected_keys:
    raise RuntimeError(
        f"{len(load_result.unexpected_keys)} checkpoint keys in {UNET_CHECKPOINT!r} "
        f"do not match any UNet parameter, e.g. {load_result.unexpected_keys[:5]}"
    )
# The token-embedding handler works across both of the pipeline's text
# encoders and their matching tokenizers.
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
# Resolve relative to the working directory, the same way the model cache
# ("./cache") and the UNet checkpoint ("ft_masked/unet/...") are resolved,
# rather than through a machine-specific absolute home-directory path.
embhandler.load_embeddings("ft_masked/embeddings/checkpoint-1000.pti")
# Substitute the handler's inserted token strings for the "<TOK>" placeholder
# so the prompt references the loaded embeddings.
prompt = "3d render portrait of <TOK> ninja with water elemental, 4 K, cinematic, portrait".replace(
    "<TOK>", "".join(embhandler.inserting_toks)
)
image = pipe(prompt).images[0]
# A bare expression is only displayed inside a notebook; run as a script, the
# generated image would be discarded, so write it to disk.
image.save("output.png")
# (End of gist. GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.")