Script for generating images from NAIv1
from dataclasses import dataclass
from einops import rearrange
import re
import torch
from torch import BoolTensor, FloatTensor, IntTensor, LongTensor, inference_mode
from torch.nn.functional import pad
from itertools import islice
from typing import Generator, Iterable, Iterator, Optional, Protocol, TypeVar
from typing_extensions import override
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL, DecoderOutput
from os import listdir
from pathlib import Path
from PIL import Image
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from transformers.utils.generic import TensorType, PaddingStrategy
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPTextModel
from transformers.models.clip.tokenization_clip_fast import CLIPTokenizerFast
from k_diffusion.external import DiscreteEpsDDPMDenoiser
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras, sample_euler_ancestral
T = TypeVar('T')
# https://github.com/python/cpython/issues/98363
def batched(iterable: Iterable[T], n: int) -> Generator[list[T], None, None]:
"Batch data into lists of length n. The last batch may be shorter."
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be >= 1")
it: Iterator[T] = iter(iterable)
while batch := list(islice(it, n)):
yield batch
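# find the next free output index by scanning out_dir for existing '<index>_*.png' / '<index>_*.jpg' samples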
def get_next_out_ix(out_dir: Path) -> int:
pattern = re.compile(r"^(\d+)_.*\.(png|jpg)$")
existing_samps: list[str] = [x for x in listdir(out_dir) if pattern.match(x)]
def get_ix(x: str) -> int:
return int(pattern.match(x).group(1))
existing_samps.sort(key=get_ix)
return get_ix(existing_samps[-1]) + 1 if existing_samps else 0
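# scaled-linear beta schedule (linear in sqrt-beta), using Stable Diffusion v1's training values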
def get_betas(
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
device: Optional[str | int | torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> FloatTensor:
return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=dtype, device=device) ** 2
def get_alphas(betas: FloatTensor) -> FloatTensor:
return 1.0 - betas
class Denoiser(Protocol):
@staticmethod
def __call__(x: FloatTensor, timestep: LongTensor, encoder_hidden_states: FloatTensor, *args, **kwargs) -> FloatTensor: ...
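# adapts the diffusers UNet to a plain callable: casts latents to the UNet's dtype (fp16 here)
# and returns the eps prediction in the sampler's working dtype (fp32)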
@dataclass
class Adapter:
unet: UNet2DConditionModel
unet_dtype: torch.dtype = torch.float32
sampling_dtype: torch.dtype = torch.float32
@override
def __call__(
self,
x: FloatTensor,
timestep: LongTensor,
encoder_hidden_states: FloatTensor,
) -> FloatTensor:
out: UNet2DConditionOutput = self.unet(
x.to(self.unet_dtype),
timestep,
encoder_hidden_states=encoder_hidden_states,
)
return out.sample.type(self.sampling_dtype)
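# k-diffusion's DiscreteEpsDDPMDenoiser translates sigma-space sampler calls into
# discrete-timestep eps predictions (and applies the corresponding input/output scalings)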
class HFEPSDenoiser(DiscreteEpsDDPMDenoiser):
inner_model: Denoiser
@override
def get_eps(self, x: FloatTensor, timestep: LongTensor, encoder_hidden_states: FloatTensor, *args, **kwargs) -> FloatTensor:
return self.inner_model(
x,
timestep,
encoder_hidden_states,
*args,
**kwargs,
)
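# classifier-free guidance: run the unconditional and conditional passes as one doubled batch
# (uncond embeddings first, cond second), then blend: uncond + (cond - uncond) * cfg_scale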
@dataclass
class CFGWrapper:
delegate: Denoiser
@override
def __call__(
self,
x: FloatTensor,
timestep: LongTensor,
encoder_hidden_states: FloatTensor,
cfg_scale: float | FloatTensor,
*args,
**kwargs,
) -> FloatTensor:
out = self.delegate(
x.repeat(2, *(1,)*(x.ndim-1)),
timestep.repeat(2, *(1,)*(timestep.ndim-1)),
encoder_hidden_states.repeat_interleave(x.size(0), dim=0),
*args,
**kwargs,
)
uncond, cond = out.chunk(2)
return uncond + (cond - uncond) * cfg_scale
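# load NAIv1's UNet and VAE (diffusers layout, fp16) plus the OpenAI CLIP ViT-L/14 text encoder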
device = torch.device('cuda')
unet_dtype = torch.float16
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
'NovelAI/nai-anime-v1-full',
torch_dtype=unet_dtype,
use_safetensors=True,
subfolder='unet',
device_map={'': device.type},
).eval()
adapted = Adapter(unet, unet_dtype=unet_dtype, sampling_dtype=torch.float32)
cfged = CFGWrapper(adapted)
vae_dtype = torch.float16
vae: AutoencoderKL = AutoencoderKL.from_pretrained(
'NovelAI/nai-anime-v1-full',
torch_dtype=vae_dtype,
use_safetensors=True,
subfolder='vae',
device_map={'': device.type},
).eval()
clip: CLIPTextModel = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14",
torch_dtype=torch.float16,
use_safetensors=True,
device_map={'': device.type},
    # skip the final transformer layer ('CLIP skip'): conditioning uses the penultimate hidden state
num_hidden_layers=11,
).eval()
tokenizer: CLIPTokenizerFast = CLIPTokenizerFast.from_pretrained(
"openai/clip-vit-large-patch14",
)
betas = get_betas(device=device)
alphas_cumprod = get_alphas(betas).cumprod(-1)
x0_unet = HFEPSDenoiser(cfged, alphas_cumprod, quantize=True)
prompt = 'masterpiece, best quality, 1girl, purple eyes, short hair, ruffled blouse, red blouse, pleated skirt, blonde hair, green scarf, blunt bangs, blue skirt, fang, converse, black sneakers, no socks, small breasts, long sleeves, bob cut, medium skirt, masterpiece, best quality'
unprompt = 'lowres'
# unprompt = 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name'
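# long-prompt handling: tokenize without special tokens, split into chunks of 75 tokens
# (model_max_length minus BOS/EOS), wrap each chunk in BOS/EOS, encode the chunks through
# CLIP separately, then concatenate the per-chunk embeddings along the sequence dimension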
nonspecial_ctx = tokenizer.model_max_length - 2
max_segments = 3
prompts: list[str] = [unprompt, prompt]
tok_out: BatchEncoding = tokenizer(
prompts,
padding=PaddingStrategy.LONGEST,
truncation=TruncationStrategy.LONGEST_FIRST,
add_special_tokens=False,
return_tensors=TensorType.PYTORCH,
max_length=nonspecial_ctx*max_segments,
)
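# measure each prompt's unpadded token count: the index of its first pad token, or the full row length if nothing was padded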
input_ids: LongTensor = tok_out.input_ids
is_pad_tok: BoolTensor = input_ids == tokenizer.pad_token_id
has_pad_tok: BoolTensor = is_pad_tok.any(dim=-1, keepdim=True)
first_pad_ix: LongTensor = torch.argmax(is_pad_tok.int(), dim=-1, keepdim=True)
after_toks_ix: LongTensor = torch.where(has_pad_tok, first_pad_ix, input_ids.size(-1))
tok_lens = after_toks_ix
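# round the longest prompt up to a whole number of 75-token segments (capped at max_segments),
# then trim/pad every prompt to that common length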
required_segments = (tok_lens.max() + nonspecial_ctx - 1) // nonspecial_ctx
keep_len = required_segments.clamp_max(max_segments) * nonspecial_ctx
trimmed = input_ids[..., :keep_len]
toks = pad(trimmed, (0, keep_len-trimmed.size(-1)), value=tokenizer.pad_token_id)
toks = toks.unflatten(-1, (-1, nonspecial_ctx))
toks = pad(toks, (1, 0), value=tokenizer.bos_token_id)
toks = pad(toks, (0, 1), value=tokenizer.eos_token_id)
toks = toks.to(device)
clip_batch_dims = toks.shape[:-1]
with inference_mode():
clip_out: BaseModelOutputWithPooling = clip(toks.flatten(end_dim=-2))
text_emb: FloatTensor = rearrange(clip_out.last_hidden_state, '(batch seg) seq dim -> batch (seg seq) dim', seg=required_segments).to(unet_dtype)
# height_px, width_px = 512, 512
# height_px, width_px = 1216, 832
height_px, width_px = 768, 512
height_l, width_l = height_px//8, width_px//8
bsz=2
start_seed=42
sample_count=4
seeds: list[int] = list(range(start_seed, start_seed + sample_count))
# initial-latent buffer; re-filled with freshly-seeded noise for each batch in the sampling loop below
x = torch.empty(
    (bsz, vae.config.latent_channels, height_l, width_l),
    dtype=torch.float32,
    device=device,
)
gen = torch.Generator(device=device)
steps=23
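# Karras-spaced sigma schedule spanning the model's trained sigma range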
sigmas: FloatTensor = get_sigmas_karras(
steps,
# sigh
sigma_min=x0_unet.sigma_min.cpu(),
sigma_max=x0_unet.sigma_max.cpu(),
device=device,
)
cfg_scale = 5
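# kwargs forwarded through the k-diffusion wrapper to CFGWrapper on every sampler step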
extra_args = {
"encoder_hidden_states": text_emb,
"cfg_scale": cfg_scale,
}
out_dir = Path('out')
out_dir.mkdir(exist_ok=True, parents=True)
next_out_ix: int = get_next_out_ix(out_dir)
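# sample in batches of bsz, saving each image as '<index>_s<seed>.png'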
for batch_ix, batch_seeds in enumerate(batched(seeds, bsz)):
    batch_start_ix: int = next_out_ix + batch_ix * bsz
    out_ixs: list[int] = [batch_start_ix + ix for ix in range(len(batch_seeds))]
    # draw fresh initial noise for this batch, one seed per sample, so each saved filename reflects the seed that produced it
    for out_rand, seed in zip(x.unbind(), batch_seeds):
        torch.randn(
            out_rand.shape,
            dtype=torch.float32,
            device=device,
            generator=gen.manual_seed(seed),
            out=out_rand,
        )
    # k-diffusion samplers expect the initial latents at sigma_max scale
    x *= sigmas[0]
with inference_mode():
noise_sampler = BrownianTreeNoiseSampler(
x,
sigma_max=sigmas[0],
sigma_min=sigmas[-2],
seed=batch_seeds,
transform=x0_unet.sigma_to_t,
)
x0: FloatTensor = sample_euler_ancestral(
x0_unet,
x,
sigmas,
extra_args=extra_args,
noise_sampler=noise_sampler,
)
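        # undo the latent scaling factor before handing the latents to the VAE decoder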
x0 /= vae.config.scaling_factor
decoder_out: DecoderOutput = vae.decode(x0.to(vae_dtype))
# -1 to 1, ish
rgb = decoder_out.sample
# [0, 1] range
# rgb.add_(1).div_(2).clamp_(0, 1)
# Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer
rgb.mul_(127.5).add_(128).clamp_(0, 255)
rgb = rearrange(rgb, '... c h w -> ... h w c')
rgb_np = rgb.to("cpu", torch.uint8).numpy()
ims: list[Image.Image] = [Image.fromarray(im) for im in rgb_np]
for out_ix, seed, im in zip(out_ixs, batch_seeds, ims):
im.save(out_dir / f'{out_ix}_s{seed}.png')
pass
Birch-san commented Jun 15, 2025

Save the above Python script as naiv1_inference.py.

Save this text as requirements.txt:

torch
torchvision
einops
k-diffusion
diffusers
transformers

Then run like so:

# install dependencies:
pip install -r requirements.txt

# generate images
python -m naiv1_inference
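
Generated images are written to ./out/, named like 0_s42.png (running output index, then seed).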
