Script for generating images from NAIv1
from dataclasses import dataclass
from einops import rearrange
import re
import torch
from torch import BoolTensor, FloatTensor, IntTensor, LongTensor, inference_mode
from torch.nn.functional import pad
from itertools import islice
from typing import Generator, Iterable, Iterator, Optional, Protocol, TypeVar
from typing_extensions import override
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL, DecoderOutput
from os import listdir
from pathlib import Path
from PIL import Image
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from transformers.utils.generic import TensorType, PaddingStrategy
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPTextModel
from transformers.models.clip.tokenization_clip_fast import CLIPTokenizerFast
from k_diffusion.external import DiscreteEpsDDPMDenoiser
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras, sample_euler_ancestral
T = TypeVar('T')

# https://github.com/python/cpython/issues/98363
def batched(iterable: Iterable[T], n: int) -> Generator[list[T], None, None]:
    "Batch data into lists of length n. The last batch may be shorter."
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError("n must be >= 1")
    it: Iterator[T] = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch
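# e.g. list(batched(range(5), 2)) == [[0, 1], [2, 3], [4]]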
def get_next_out_ix(out_dir: Path) -> int:
    pattern = re.compile(r"^(\d+)_.*\.(png|jpg)$")
    existing_samps: list[str] = [x for x in listdir(out_dir) if pattern.match(x)]
    def get_ix(x: str) -> int:
        return int(pattern.match(x).group(1))
    existing_samps.sort(key=get_ix)
    return get_ix(existing_samps[-1]) + 1 if existing_samps else 0
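# e.g. if out_dir already contains 0_s42.png and 1_s43.png, the next index is 2; an empty dir gives 0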
def get_betas(
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    device: Optional[str | int | torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> FloatTensor:
    return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=dtype, device=device) ** 2

def get_alphas(betas: FloatTensor) -> FloatTensor:
    return 1.0 - betas
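# These defaults reproduce the "scaled_linear" DDPM beta schedule (a linspace in sqrt-beta space from
# 0.00085 to 0.012 over 1000 timesteps) used by Stable-Diffusion-v1-family checkpoints, NAIv1 included.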
class Denoiser(Protocol):
    @staticmethod
    def __call__(x: FloatTensor, timestep: LongTensor, encoder_hidden_states: FloatTensor, *args, **kwargs) -> FloatTensor: ...
@dataclass
class Adapter:
    unet: UNet2DConditionModel
    unet_dtype: torch.dtype = torch.float32
    sampling_dtype: torch.dtype = torch.float32

    @override
    def __call__(
        self,
        x: FloatTensor,
        timestep: LongTensor,
        encoder_hidden_states: FloatTensor,
    ) -> FloatTensor:
        out: UNet2DConditionOutput = self.unet(
            x.to(self.unet_dtype),
            timestep,
            encoder_hidden_states=encoder_hidden_states,
        )
        return out.sample.type(self.sampling_dtype)
class HFEPSDenoiser(DiscreteEpsDDPMDenoiser):
    inner_model: Denoiser

    @override
    def get_eps(self, x: FloatTensor, timestep: LongTensor, encoder_hidden_states: FloatTensor, *args, **kwargs) -> FloatTensor:
        return self.inner_model(
            x,
            timestep,
            encoder_hidden_states,
            *args,
            **kwargs,
        )
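# DiscreteEpsDDPMDenoiser adapts an epsilon-predicting UNet to k-diffusion's sigma-based sampler API:
# its forward() applies the c_in/c_out scalings, converts sigma to a discrete timestep via sigma_to_t,
# calls get_eps, and returns the denoised (x0) prediction, which is why the instance below is named x0_unet.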
@dataclass
class CFGWrapper:
    delegate: Denoiser

    @override
    def __call__(
        self,
        x: FloatTensor,
        timestep: LongTensor,
        encoder_hidden_states: FloatTensor,
        cfg_scale: float | FloatTensor,
        *args,
        **kwargs,
    ) -> FloatTensor:
        out = self.delegate(
            x.repeat(2, *(1,)*(x.ndim-1)),
            timestep.repeat(2, *(1,)*(timestep.ndim-1)),
            encoder_hidden_states.repeat_interleave(x.size(0), dim=0),
            *args,
            **kwargs,
        )
        uncond, cond = out.chunk(2)
        return uncond + (cond - uncond) * cfg_scale
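# Classifier-free guidance: the latent batch is doubled, the first half is paired with the uncond
# embedding and the second half with the cond embedding (encoder_hidden_states is expected as
# [uncond, cond] along dim 0), and the result is extrapolated as uncond + (cond - uncond) * cfg_scale.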
device = torch.device('cuda')

unet_dtype = torch.float16
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
    'NovelAI/nai-anime-v1-full',
    torch_dtype=unet_dtype,
    use_safetensors=True,
    subfolder='unet',
    device_map={'': device.type},
).eval()
adapted = Adapter(unet, unet_dtype=unet_dtype, sampling_dtype=torch.float32)
cfged = CFGWrapper(adapted)

vae_dtype = torch.float16
vae: AutoencoderKL = AutoencoderKL.from_pretrained(
    'NovelAI/nai-anime-v1-full',
    torch_dtype=vae_dtype,
    use_safetensors=True,
    subfolder='vae',
    device_map={'': device.type},
).eval()

clip: CLIPTextModel = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=torch.float16,
    use_safetensors=True,
    device_map={'': device.type},
    # skip final layer
    num_hidden_layers=11,
).eval()
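# num_hidden_layers=11 loads only the first 11 of CLIP ViT-L/14's 12 transformer layers, so
# last_hidden_state is effectively the penultimate layer's output (the "CLIP skip" conditioning
# NAIv1 was trained with), still followed by the final layer norm.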
tokenizer: CLIPTokenizerFast = CLIPTokenizerFast.from_pretrained(
    "openai/clip-vit-large-patch14",
)

betas = get_betas(device=device)
alphas_cumprod = get_alphas(betas).cumprod(-1)
x0_unet = HFEPSDenoiser(cfged, alphas_cumprod, quantize=True)
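# The wrapper derives its sigma table from alphas_cumprod as sigma_t = sqrt((1 - abar_t) / abar_t);
# quantize=True makes sigma_to_t snap each requested sigma to the nearest discrete training timestep.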
prompt = 'masterpiece, best quality, 1girl, purple eyes, short hair, ruffled blouse, red blouse, pleated skirt, blonde hair, green scarf, blunt bangs, blue skirt, fang, converse, black sneakers, no socks, small breasts, long sleeves, bob cut, medium skirt, masterpiece, best quality'
unprompt = 'lowres'
# unprompt = 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name'

nonspecial_ctx = tokenizer.model_max_length - 2
max_segments = 3

prompts: list[str] = [unprompt, prompt]
tok_out: BatchEncoding = tokenizer(
    prompts,
    padding=PaddingStrategy.LONGEST,
    truncation=TruncationStrategy.LONGEST_FIRST,
    add_special_tokens=False,
    return_tensors=TensorType.PYTORCH,
    max_length=nonspecial_ctx*max_segments,
)
input_ids: LongTensor = tok_out.input_ids
is_pad_tok: BoolTensor = input_ids == tokenizer.pad_token_id
has_pad_tok: BoolTensor = is_pad_tok.any(dim=-1, keepdim=True)
first_pad_ix: LongTensor = torch.argmax(is_pad_tok.int(), dim=-1, keepdim=True)
after_toks_ix: LongTensor = torch.where(has_pad_tok, first_pad_ix, input_ids.size(-1))
tok_lens = after_toks_ix
required_segments = (tok_lens.max() + nonspecial_ctx - 1) // nonspecial_ctx
keep_len = required_segments.clamp_max(max_segments) * nonspecial_ctx

trimmed = input_ids[..., :keep_len]
toks = pad(trimmed, (0, keep_len-trimmed.size(-1)), value=tokenizer.pad_token_id)
toks = toks.unflatten(-1, (-1, nonspecial_ctx))
toks = pad(toks, (1, 0), value=tokenizer.bos_token_id)
toks = pad(toks, (0, 1), value=tokenizer.eos_token_id)
toks = toks.to(device)
clip_batch_dims = toks.shape[:-1]

with inference_mode():
    clip_out: BaseModelOutputWithPooling = clip(toks.flatten(end_dim=-2))
    text_emb: FloatTensor = rearrange(clip_out.last_hidden_state, '(batch seg) seq dim -> batch (seg seq) dim', seg=required_segments).to(unet_dtype)
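# The prompts were tokenized without special tokens, split into up to max_segments chunks of 75 tokens,
# and each chunk wrapped in its own BOS/EOS. Encoding the chunks independently and concatenating their
# hidden states along the sequence axis gives the UNet an effective context longer than CLIP's 77 tokens.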
# height_px, width_px = 512, 512
# height_px, width_px = 1216, 832
height_px, width_px = 768, 512
height_l, width_l = height_px//8, width_px//8

bsz = 2
start_seed = 42
sample_count = 4
seeds: list[int] = list(range(start_seed, start_seed + sample_count))
x = torch.empty(
    (bsz, vae.config.latent_channels, height_l, width_l),
    dtype=torch.float32,
    device=device,
)
gen = torch.Generator(device=device)
# x is re-filled with per-seed noise at the top of each batch iteration below, so that every saved
# image corresponds to the seed recorded in its filename
steps = 23
sigmas: FloatTensor = get_sigmas_karras(
    steps,
    # sigh
    sigma_min=x0_unet.sigma_min.cpu(),
    sigma_max=x0_unet.sigma_max.cpu(),
    device=device,
)
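# get_sigmas_karras returns steps+1 sigmas spaced per Karras et al. (2022), descending from sigma_max
# to sigma_min with a trailing zero appended; sigmas[-2] is therefore the smallest nonzero sigma.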
cfg_scale = 5
extra_args = {
    "encoder_hidden_states": text_emb,
    "cfg_scale": cfg_scale,
}
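# extra_args are forwarded untouched by sample_euler_ancestral into every model call, flowing through
# HFEPSDenoiser.get_eps's **kwargs down to CFGWrapper, which consumes cfg_scale and passes
# encoder_hidden_states on to the UNet.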
out_dir = Path('out')
out_dir.mkdir(exist_ok=True, parents=True)
next_out_ix: int = get_next_out_ix(out_dir)
for batch_ix, batch_seeds in enumerate(batched(seeds, bsz)):
    batch_start_ix: int = next_out_ix + batch_ix * bsz
    out_ixs: list[int] = [batch_start_ix + ix for ix in range(len(batch_seeds))]
    # draw this batch's initial latents from its own seeds, so each saved image corresponds
    # to the seed recorded in its filename
    for out_rand, seed in zip(x.unbind(), batch_seeds):
        torch.randn(
            out_rand.shape,
            dtype=torch.float32,
            device=device,
            generator=gen.manual_seed(seed),
            out=out_rand,
        )
    with inference_mode():
        noise_sampler = BrownianTreeNoiseSampler(
            x,
            sigma_max=sigmas[0],
            sigma_min=sigmas[-2],
            seed=batch_seeds,
            transform=x0_unet.sigma_to_t,
        )
        x0: FloatTensor = sample_euler_ancestral(
            x0_unet,
            # k-diffusion convention: start sampling from noise scaled up to the first sigma
            x * sigmas[0],
            sigmas,
            extra_args=extra_args,
            noise_sampler=noise_sampler,
        )
        x0 /= vae.config.scaling_factor
        decoder_out: DecoderOutput = vae.decode(x0.to(vae_dtype))
        # -1 to 1, ish
        rgb = decoder_out.sample
        # [0, 1] range
        # rgb.add_(1).div_(2).clamp_(0, 1)
        # Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
        rgb.mul_(127.5).add_(128).clamp_(0, 255)
        rgb = rearrange(rgb, '... c h w -> ... h w c')
        rgb_np = rgb.to("cpu", torch.uint8).numpy()
        ims: list[Image.Image] = [Image.fromarray(im) for im in rgb_np]
    for out_ix, seed, im in zip(out_ixs, batch_seeds, ims):
        im.save(out_dir / f'{out_ix}_s{seed}.png')
pass
save the above Python script as naiv1_inference.py.
save this text as requirements.txt:
then run like so: