-
-
Save jagtesh/cad139a89e264d1ad3fb69f34be6973d to your computer and use it in GitHub Desktop.
NitroDiffusion + One Step Refiner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from diffusers import LCMScheduler | |
from diffusers import DiffusionPipeline, UNet2DConditionModel | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
class TimestepShiftLCMScheduler(LCMScheduler):
    """LCM scheduler that exposes rescaled ("shifted") timesteps to its caller.

    After ``set_timesteps`` the public ``timesteps`` attribute holds the
    parent's schedule multiplied by
    ``shifted_timestep / num_train_timesteps`` and truncated to integers.
    The untouched schedule is kept in ``origin_timesteps`` and is swapped
    back in for the duration of each ``step`` call, so the parent's update
    rule runs against the original schedule.

    Args:
        *args: Forwarded unchanged to ``LCMScheduler.__init__``.
        shifted_timestep: Numerator of the rescaling factor; stored in the
            scheduler config under the same name.
        **kwargs: Forwarded unchanged to ``LCMScheduler.__init__``.
    """

    def __init__(self, *args, shifted_timestep=250, **kwargs):
        super().__init__(*args, **kwargs)
        # Record the extra option in the config so it is reachable as
        # ``self.config.shifted_timestep``.
        self.register_to_config(shifted_timestep=shifted_timestep)

    def set_timesteps(self, *args, **kwargs):
        """Build the parent's schedule, then publish the shifted variant.

        All arguments are forwarded to ``LCMScheduler.set_timesteps``.
        """
        super().set_timesteps(*args, **kwargs)
        # Keep a copy of the parent's schedule for use inside ``step``.
        self.origin_timesteps = self.timesteps.clone()
        self.shifted_timesteps = (
            self.timesteps * self.config.shifted_timestep
            / self.config.num_train_timesteps
        ).long()
        self.timesteps = self.shifted_timesteps

    def step(self, model_output, timestep, sample, generator=None, return_dict=True):
        """Run one parent ``step`` against the original (unshifted) schedule.

        Arguments and return value are those of ``LCMScheduler.step``.
        """
        # Resolve the step index while ``self.timesteps`` still holds the
        # shifted schedule (presumably because ``timestep`` is one of the
        # shifted values handed out to the caller; confirm against the
        # pipeline in use).
        if self.step_index is None:
            self._init_step_index(timestep)
        self.timesteps = self.origin_timesteps
        try:
            return super().step(model_output, timestep, sample, generator, return_dict)
        finally:
            # Always re-publish the shifted schedule, even if the parent
            # raised, so the scheduler is never left exposing the original
            # timesteps to its caller.
            self.timesteps = self.shifted_timesteps
# Model identifiers and checkpoint location.
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
repo = "ChenDY/NitroFusion"
# NitroSD-Realism
ckpt = "nitrosd-realism_unet.safetensors"

# Instantiate the UNet from the base model's config, move it to the GPU in
# half precision, then load the downloaded checkpoint weights into it.
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet")
unet = unet.to("cuda", torch.float16)
ckpt_path = hf_hub_download(repo, ckpt)
# The loaded state dict is deliberately not bound to a name, so it can be
# released as soon as ``load_state_dict`` returns.
unet.load_state_dict(load_file(ckpt_path, device="cuda"))

# Scheduler built from the base model's scheduler config plus the shift option.
scheduler = TimestepShiftLCMScheduler.from_pretrained(
    base_model_id,
    subfolder="scheduler",
    shifted_timestep=250,
)
scheduler.config.original_inference_steps = 4

# Loading options shared by the base and refiner pipelines.
_FP16_LOAD_KWARGS = {
    "torch_dtype": torch.float16,
    "variant": "fp16",
    "use_safetensors": True,
}

# Base pipeline, using the custom UNet and scheduler created above.
base = DiffusionPipeline.from_pretrained(
    base_model_id,
    unet=unet,
    scheduler=scheduler,
    **_FP16_LOAD_KWARGS,
)
base = base.to("cuda")

# Refiner pipeline; it reuses the base pipeline's second text encoder and VAE
# objects instead of loading its own.
refiner = DiffusionPipeline.from_pretrained(
    refiner_model_id,
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    **_FP16_LOAD_KWARGS,
)
refiner = refiner.to("cuda")
# Value passed as ``denoising_end`` to the base pipeline and as
# ``denoising_start`` to the refiner, i.e. the hand-over point between them.
high_noise_frac = 0.8

# Interactive loop: read a prompt, run base then refiner, save the result.
while True:
    try:
        prompt = input('NDXL> ')
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D, exhausted stdin, or Ctrl-C at the prompt: leave the loop
        # cleanly instead of dying with a traceback.
        print()
        break

    # Base pass: one inference step, returned as latents for the refiner.
    image = base(
        prompt=prompt,
        num_inference_steps=1,
        guidance_scale=0,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images

    # Refiner pass: one inference step starting from the base latents; take
    # the first image of the returned batch.
    image = refiner(
        prompt=prompt,
        num_inference_steps=1,
        denoising_start=high_noise_frac,
        image=image,
    ).images[0]

    # Each iteration overwrites the same output file.
    image.save('temp.png', 'PNG')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment