Created
July 14, 2023 06:13
-
-
Save jfischoff/5083630a183f989df82492042e21cd2f to your computer and use it in GitHub Desktop.
AnimateDiff simple runner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
from animatediff.pipelines.pipeline_animation import AnimationPipeline | |
import torch | |
from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler | |
from tqdm.auto import tqdm | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from animatediff.models.unet import UNet3DConditionModel | |
from diffusers.utils.import_utils import is_xformers_available | |
from safetensors import safe_open | |
import diffusers | |
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora | |
import torchvision | |
from animatediff.utils.util import save_videos_grid | |
# Build an AnimateDiff animation pipeline on top of a local Stable Diffusion v1.5 checkpoint.
pretrained_model_path = Path("/home/jonathan/models/v1-5")

### >>> create validation pipeline >>> ###
# Extra options forwarded to UNet3DConditionModel.from_pretrained_2d. They configure the
# motion (temporal attention) modules. NOTE(review): these presumably have to match the
# motion-module checkpoint loaded later (mm_sd_v14), and temporal_position_encoding_max_len
# presumably caps the usable video_length — confirm.
unet_additional_kwargs = {
    "unet_use_cross_frame_attention": False,
    "unet_use_temporal_attention": False,
    "use_motion_module": True,
    "motion_module_resolutions": [1, 2, 4, 8],
    "motion_module_mid_block": False,
    "motion_module_decoder_only": False,
    "motion_module_type": "Vanilla",
    "motion_module_kwargs": {
        "num_attention_heads": 8,
        "num_transformer_block": 1,
        "attention_block_types": ["Temporal_Self", "Temporal_Self"],
        "temporal_position_encoding": True,
        "temporal_position_encoding_max_len": 24,
        "temporal_attention_dim_div": 1,
    },
}

# Text encoder, tokenizer and VAE come unchanged from the SD checkpoint's subfolders.
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
# The 3D UNet is built from the 2D SD UNet weights plus the extra motion-module options.
unet = UNet3DConditionModel.from_pretrained_2d(
    pretrained_model_path,
    subfolder="unet",
    unet_additional_kwargs=unet_additional_kwargs,
)

# xformers memory-efficient attention is a hard requirement of this script. Raise a real
# exception rather than `assert False`, which is stripped under `python -O` and would let
# the script continue without it.
if not is_xformers_available():
    raise RuntimeError(
        "xformers is not available; install it to enable memory-efficient attention"
    )
unet.enable_xformers_memory_efficient_attention()

# Noise-schedule settings shared by the scheduler choices below.
scheduler_kwargs = {
    "num_train_timesteps": 1000,
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "linear",
    # Disabled options, kept for experimenting (e.g. with the DDIMScheduler alternative):
    # "steps_offset": 1,
    # "clip_sample": False,
}

pipeline = AnimationPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    # Alternative: scheduler=DDIMScheduler(**scheduler_kwargs),
    scheduler=EulerAncestralDiscreteScheduler(**scheduler_kwargs),
).to("cuda")
# 1. unet ckpt
# 1.1 motion module: load the motion-module weights into the pipeline's UNet.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code; only load
# checkpoints from a trusted source.
motion_module_state_dict = torch.load("models/Motion_Module/mm_sd_v14.ckpt", map_location="cpu")

# A top-level "global_step" entry is training metadata, not a parameter. load_state_dict
# reports such keys as unexpected even with strict=False, which would make the check
# below reject an otherwise valid checkpoint, so strip it (and report it) first.
if "global_step" in motion_module_state_dict:
    print("motion module global_step:", motion_module_state_dict.pop("global_step"))

# strict=False: the checkpoint is expected to cover only part of the UNet (presumably just
# the motion-module layers — confirm), so the remaining parameters are listed as "missing".
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
if unexpected:
    raise RuntimeError(
        f"motion module checkpoint has {len(unexpected)} unexpected keys, "
        f"e.g. {unexpected[:5]}"
    )
print("missing", len(missing))
# 2. LoRA: read every tensor of the .safetensors file onto the CPU, then hand them to the
#    project helper convert_lora. NOTE(review): presumably it merges the LoRA deltas into
#    the pipeline weights scaled by alpha — confirm against convert_lora.
lora_path = Path("/home/jonathan/stable-diffusion-webui/models/Lora/doodle.safetensors")
with safe_open(lora_path, framework="pt", device="cpu") as lora_file:
    state_dict = {name: lora_file.get_tensor(name) for name in lora_file.keys()}
pipeline = convert_lora(pipeline, state_dict, alpha=0.7)
# 3. Sampling: run the pipeline once for a single prompt and save the result as a GIF.
prompt = "doodle of a bear, childrens drawing style"
# Alternative prompt kept for reference:
# prompt = "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress, vintage art, illustration"

# NOTE(review): the "(...:1.3)" weights and the names EasyNegative / ng_deepnegative_v1_75t
# are webui conventions. Nothing in this script parses weights or loads those embeddings,
# so they are presumably fed to the tokenizer as plain text — confirm.
negative_prompt = "deformed, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, malformed hands, blurry, ((((mutated hands and fingers)))), duplicate images, duplicate, (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, EasyNegative, ng_deepnegative_v1_75t"

# Generation settings: 25 denoising steps, guidance 7.5, 512x512 output, 16-long video.
generation_settings = dict(
    negative_prompt=negative_prompt,
    num_inference_steps=25,
    guidance_scale=7.5,
    width=512,
    height=512,
    video_length=16,
)
pipeline_output = pipeline(prompt, **generation_settings)
sample = pipeline_output.videos
print("sample.shape", sample.shape)

# All videos in `sample` go into one grid image file, laid out in a single row.
save_videos_grid(sample, "output/sample.gif", n_rows=1)
# 4. Also save every frame as an individual PNG.
# `sample` is treated here as (batch, channels, frames, height, width), the layout implied
# by the original squeeze(0) + permute(1, 0, 2, 3). NOTE(review): confirm against
# AnimationPipeline's `.videos` output.
output_dir = Path("output")
output_dir.mkdir(parents=True, exist_ok=True)

# Clamp to [0, 1] (the range save_image expects), move to CPU, and reorder to
# (batch, frames, C, H, W) so each frame is a (C, H, W) tensor as torchvision requires.
videos = sample.clamp(0, 1).cpu().permute(0, 2, 1, 3, 4)
single_video = videos.shape[0] == 1
for video_idx, frames in enumerate(videos):
    for frame_idx, frame in enumerate(frames):
        # Keep the historical "image_<frame>.png" names for the usual single-video case;
        # prefix the video index only when the batch holds several videos.
        if single_video:
            file_name = f"image_{frame_idx}.png"
        else:
            file_name = f"image_{video_idx}_{frame_idx}.png"
        torchvision.utils.save_image(frame, output_dir / file_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment