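# Excerpt: LatentSync lip-sync pipeline __call__ with step-by-step timing
# instrumentation. This method is assumed to live on the pipeline class; the
# surrounding module is expected to provide math, os, shutil, subprocess, torch,
# tqdm, soundfile (as sf), typing helpers (Callable, List, Optional, Union),
# the scheduler classes DDIMScheduler and PNDMScheduler, and the project
# helpers read_audio, write_video, and ImageProcessor.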
@torch.no_grad()
def __call__(
    self,
    video_path: str,
    audio_path: str,
    video_out_path: str,
    video_mask_path: Optional[str] = None,
    num_frames: int = 16,
    video_fps: int = 25,
    audio_sample_rate: int = 16000,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 25,
    guidance_scale: float = 1.5,
    weight_dtype: Optional[torch.dtype] = torch.float16,
    eta: float = 0.0,
    mask: str = "fix_mask",
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    debug: bool = False,
    video_skip_frames: int = 1,
    video_num_workers: int = 4,
    **kwargs,
):
    """
    Function invoked when calling the pipeline for generation.

    Extracts faces from the input video, encodes them to latents, denoises them
    in chunks of `num_frames` frames conditioned on Whisper audio features, then
    restores the synced faces into the original frames and muxes the result with
    the trimmed audio track via ffmpeg.
    """
    import time

    overall_start_time = time.time()

    # Initialize the image processor here
    self.image_processor = ImageProcessor(height, mask=mask, device="cuda", debug=debug)

    print("\n" + "=" * 80)
    print(f"Starting LatentSync pipeline with {num_inference_steps} inference steps")
    print(f"Video: {video_path}")
    print(f"Audio: {audio_path}")
    print(f"Output: {video_out_path}")
    print("=" * 80 + "\n")

    # Extract audio features
    print("\n[1/5] Extracting audio features...")
    audio_start = time.time()
    audio_samples = read_audio(audio_path)
    audio_time = time.time() - audio_start
    print(f"Audio feature extraction completed in {audio_time:.2f} seconds")

    # Extract face from video
    print("\n[2/5] Processing video frames...")
    video_start = time.time()
    faces, debug_faces, original_video_frames, boxes, affine_matrices = self.affine_transform_video(
        video_path, video_skip_frames=video_skip_frames, video_num_workers=video_num_workers
    )
    video_time = time.time() - video_start
    print(f"Video processing completed in {video_time:.2f} seconds")
    # Get latents
    print("\n[3/5] Encoding video frames to latent space...")
    latent_start = time.time()
    all_latents = self.prepare_latents(
        1,
        len(faces),
        self.vae.config.latent_channels,
        height or self.unet.config.sample_size * self.vae_scale_factor,
        width or self.unet.config.sample_size * self.vae_scale_factor,
        weight_dtype,
        self._execution_device,
        generator,
    )
    latent_time = time.time() - latent_start
    print(f"Latent encoding completed in {latent_time:.2f} seconds")

    # Prepare scheduler
    self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
    timesteps = self.scheduler.timesteps

    # Scale the initial latents by the scheduler's starting noise sigma
    all_latents = all_latents * self.scheduler.init_noise_sigma
    # Prepare audio features
    if self.unet.add_audio_layer:
        whisper_feature = self.audio_encoder.audio2feat(audio_path)
        whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
        total_frames = min(len(faces), len(whisper_chunks))
    else:
        total_frames = len(faces)

    num_inferences = math.ceil(total_frames / num_frames)

    synced_video_frames = []
    masked_video_frames = []
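
    # The video is denoised in chunks of `num_frames` frames: each chunk is
    # conditioned on the matching slice of Whisper audio embeddings, and the
    # decoded faces are restored into the original frames after the loop.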
    for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
        # Calculate the actual number of frames for this chunk
        start_idx = i * num_frames
        end_idx = min(start_idx + num_frames, total_frames)
        chunk_size = end_idx - start_idx

        if self.unet.add_audio_layer:
            audio_embeds = torch.stack(whisper_chunks[start_idx:end_idx])
            audio_embeds = audio_embeds.to(device=self._execution_device, dtype=weight_dtype)
            if guidance_scale > 1.0:
                # For classifier-free guidance, prepend zeroed (null) audio embeddings
                null_audio_embeds = torch.zeros_like(audio_embeds)
                audio_embeds = torch.cat([null_audio_embeds, audio_embeds])
        else:
            audio_embeds = None

        inference_faces = faces[start_idx:end_idx]
        latents = all_latents[:, :, start_idx:end_idx]
        pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
            inference_faces, affine_transform=False
        )
        # 7. Prepare mask latent variables
        mask_latents, masked_image_latents = self.prepare_mask_latents(
            masks,
            masked_pixel_values,
            height or self.unet.config.sample_size * self.vae_scale_factor,
            width or self.unet.config.sample_size * self.vae_scale_factor,
            weight_dtype,
            self._execution_device,
            generator,
            guidance_scale > 1.0,
        )

        # 8. Prepare image latents
        image_latents = self.prepare_image_latents(
            pixel_values,
            self._execution_device,
            weight_dtype,
            generator,
            guidance_scale > 1.0,
        )
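
        # Conditioning summary: mask_latents mark the region to regenerate,
        # masked_image_latents encode the face with that region blanked out, and
        # image_latents encode the unmasked reference face; all three are
        # concatenated with the noisy latents along the channel dimension below.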

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        if isinstance(self.scheduler, (DDIMScheduler, PNDMScheduler)):
            # Use more steps in the beginning of the sequence
            base_steps = num_inference_steps
            warmup_steps = min(base_steps + 5, 35)  # Add 5 more steps for first chunk, max 35
            if i == 0:
                self.scheduler.set_timesteps(warmup_steps, device=self._execution_device)
            else:
                self.scheduler.set_timesteps(base_steps, device=self._execution_device)
            timesteps = self.scheduler.timesteps
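
        # Note: the progress bar below is sized with num_inference_steps, so for
        # the first chunk (which uses warmup_steps timesteps) its total will not
        # match the actual number of iterations.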

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for j, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents

                # concat latents, mask, masked_image_latents in the channel dimension
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = torch.cat(
                    [latent_model_input, mask_latents, masked_image_latents, image_latents], dim=1
                )

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=audio_embeds).sample

                # perform guidance
                if guidance_scale > 1.0:
                    noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **self.prepare_extra_step_kwargs(generator, eta)
                ).prev_sample

                # call the callback, if provided
                if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and j % callback_steps == 0:
                        callback(j, t, latents)

        # Recover the pixel values
        decoded_latents = self.decode_latents(latents)
        decoded_latents = self.paste_surrounding_pixels_back(
            decoded_latents, pixel_values, 1 - masks, self._execution_device, weight_dtype
        )
        synced_video_frames.append(decoded_latents)

    synced_video_frames = self.restore_video(
        torch.cat(synced_video_frames), original_video_frames, boxes, affine_matrices
    )

    audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
    audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()

    temp_dir = "temp"
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir, exist_ok=True)

    write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=25)
    sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)

    command = (
        f"ffmpeg -threads 32 -y -loglevel error -nostdin "
        f"-i {os.path.join(temp_dir, 'video.mp4')} -i {os.path.join(temp_dir, 'audio.wav')} "
        f"-c:v libx264 -c:a aac -q:v 0 -q:a 0 {video_out_path}"
    )
    subprocess.run(command, shell=True)
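
    # Note: the command is built by string interpolation and run with shell=True,
    # so paths containing spaces or shell metacharacters will break; passing an
    # argument list to subprocess.run would be more robust.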

    # Print overall statistics
    total_time = time.time() - overall_start_time
    print("\n" + "=" * 80)
    print(f"LatentSync pipeline completed in {total_time:.2f} seconds")
    print("=" * 80)