NitroDiffusion + One Step Refiner
import torch
from diffusers import LCMScheduler
from diffusers import DiffusionPipeline, UNet2DConditionModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
class TimestepShiftLCMScheduler(LCMScheduler):
    """LCM scheduler that runs the UNet on shifted timesteps, as required by
    NitroFusion's one-step distilled checkpoints."""

    def __init__(self, *args, shifted_timestep=250, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        super().set_timesteps(*args, **kwargs)
        # Keep the original timesteps for the scheduler update; expose the
        # shifted timesteps to the UNet.
        self.origin_timesteps = self.timesteps.clone()
        self.shifted_timesteps = (
            self.timesteps * self.config.shifted_timestep /
            self.config.num_train_timesteps
        ).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        if self.step_index is None:
            self._init_step_index(timestep)
        # Swap the original timesteps back in for the denoising update, then
        # restore the shifted timesteps for the next model call.
        self.timesteps = self.origin_timesteps
        output = super().step(model_output, timestep, sample, generator, return_dict)
        self.timesteps = self.shifted_timesteps
        return output

# Load the models.
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
repo = "ChenDY/NitroFusion"

# NitroSD-Realism: one-step distilled SDXL UNet from the NitroFusion repo.
ckpt = "nitrosd-realism_unet.safetensors"
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet") \
    .to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))

scheduler = TimestepShiftLCMScheduler.from_pretrained(
    base_model_id, subfolder="scheduler",
    shifted_timestep=250,
)
scheduler.config.original_inference_steps = 4
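
# Alternative checkpoint (not part of the original gist): the ChenDY/NitroFusion
# repo also provides a one-step NitroSD-Vibrant UNet. Swapping it in should only
# change the weight file and the shifted_timestep; the values below are assumed
# from the model card and worth verifying before use.
# ckpt = "nitrosd-vibrant_unet.safetensors"
# unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
# scheduler = TimestepShiftLCMScheduler.from_pretrained(
#     base_model_id, subfolder="scheduler",
#     shifted_timestep=500,
# )
# scheduler.config.original_inference_steps = 4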
base = DiffusionPipeline.from_pretrained(
    base_model_id,
    unet=unet,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# The refiner reuses the base pipeline's second text encoder and VAE to save memory.
refiner = DiffusionPipeline.from_pretrained(
    refiner_model_id,
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")
# Fraction of the denoising schedule handled by the base model before
# handing off to the refiner.
high_noise_frac = 0.8

while True:
    prompt = input('NDXL> ')
    # One-step base pass, stopped early and returned as latents for the refiner.
    image = base(
        prompt=prompt,
        num_inference_steps=1,
        guidance_scale=0,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images
    # One-step refiner pass picks up where the base left off.
    image = refiner(
        prompt=prompt,
        num_inference_steps=1,
        denoising_start=high_noise_frac,
        image=image,
    ).images[0]
    image.save('temp.png', 'PNG')
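
# Example session (illustrative; the prompt below is not from the original gist):
#   $ python ndxl.py
#   NDXL> a misty pine forest at dawn, volumetric light
# Each prompt overwrites temp.png in the working directory with the refined result.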