# Gist: takuma104/96d241f0fd6843c231791db0d4a5c4a9 (created June 2, 2023)
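# Benchmarks peak CUDA memory for Stable Diffusion v1.5 inference across the
# four combinations of xFormers memory-efficient attention (on/off) and LoRA
# weights (loaded/not loaded), saving a 2x2 image grid for each configuration.
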
import json

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    XFormersAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
)
from PIL import Image
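

# Tile a batch of generated images into a single rows x cols contact sheet.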
def image_grid(imgs, rows=2, cols=2):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
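

# Render a boolean as "ON"/"OFF" for filenames and log lines.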
def on_off(cond):
    return "ON" if cond else "OFF"
def print_memory_usage(width, height, batch, xformers, with_lora):
    mem_bytes = torch.cuda.max_memory_allocated()
    mem_MB = int(mem_bytes / (10**6))
    stats = {
        "width": width,
        "height": height,
        "batch": batch,
        "xformers": on_off(xformers),
        "lora": on_off(with_lora),
        "mem_MB": mem_MB,
    }
    print(json.dumps(stats))
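

# Assert that every Attention module under root_module uses the expected
# attention-processor class.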
def check_attn_processor(root_module, klass):
    for _, module in root_module.named_modules():
        if isinstance(module, Attention):
            assert isinstance(module.processor, klass)
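

# Sweep every (xformers, lora) combination at a fixed size and batch, logging
# peak memory for each run.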
if __name__ == "__main__":
    prompt = "masterpiece, best quality, 1girl, at dusk"
    negative_prompt = (
        "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
        "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts"
    )

    sd_model_id = "runwayml/stable-diffusion-v1-5"
    lora_weight_model_id = "sayakpaul/civitai-light-shadow-lora"
    lora_weight_name = "light_and_shadow.safetensors"

    for xformers in [False, True]:
        for batch in [4]:
            for width, height in [(512, 768)]:
                for with_lora in [False, True]:
                    # Reset the peak-memory counter so each configuration is
                    # measured independently.
                    torch.cuda.reset_peak_memory_stats()
                    pipe = StableDiffusionPipeline.from_pretrained(
                        sd_model_id, torch_dtype=torch.float16, safety_checker=None
                    ).to("cuda")
                    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                        pipe.scheduler.config, use_karras_sigmas=True
                    )

                    # Toggle xFormers and confirm the expected attention
                    # processor is actually installed on the UNet.
                    if xformers:
                        pipe.enable_xformers_memory_efficient_attention()
                        check_attn_processor(pipe.unet, XFormersAttnProcessor)
                    else:
                        pipe.disable_xformers_memory_efficient_attention()
                        check_attn_processor(pipe.unet, AttnProcessor2_0)

                    # pipe.set_progress_bar_config(disable=True)

                    # Loading LoRA weights swaps the processors for their
                    # LoRA-aware counterparts.
                    if with_lora:
                        pipe.load_lora_weights(lora_weight_model_id, weight_name=lora_weight_name)
                        if xformers:
                            check_attn_processor(pipe.unet, LoRAXFormersAttnProcessor)
                        else:
                            check_attn_processor(pipe.unet, LoRAAttnProcessor2_0)

                    # Fixed seed so every configuration generates from the
                    # same initial latents.
                    images = pipe(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=15,
                        num_images_per_prompt=batch,
                        generator=torch.manual_seed(0),
                    ).images
                    image_grid(images).save(
                        f"generated_xf-{on_off(xformers)}_lora-{on_off(with_lora)}.png"
                    )
                    print_memory_usage(width, height, batch, xformers, with_lora)
                    # Release the pipeline before the next configuration.
                    del pipe