@markrmiller
Created March 2, 2025 18:17
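
# Note: this gist is an excerpt of a pipeline class's __call__ method, so module-level
# imports are not shown. It appears to rely on math, os, shutil, subprocess, tqdm, torch,
# soundfile as sf, typing helpers (Callable, List, Optional, Union), the diffusers
# schedulers DDIMScheduler and PNDMScheduler, and project helpers such as ImageProcessor,
# read_audio, and write_video.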
@torch.no_grad()
def __call__(
    self,
    video_path: str,
    audio_path: str,
    video_out_path: str,
    video_mask_path: str = None,
    num_frames: int = 16,
    video_fps: int = 25,
    audio_sample_rate: int = 16000,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 25,
    guidance_scale: float = 1.5,
    weight_dtype: Optional[torch.dtype] = torch.float16,
    eta: float = 0.0,
    mask: str = "fix_mask",
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    debug: bool = False,
    video_skip_frames: int = 1,
    video_num_workers: int = 4,
    **kwargs,
):
"""
Function invoked when calling the pipeline for generation.
"""
import time
overall_start_time = time.time()
# Initialize the image processor here
self.image_processor = ImageProcessor(height, mask=mask, device="cuda", debug=debug)
print("\n" + "="*80)
print(f"Starting LatentSync pipeline with {num_inference_steps} inference steps")
print(f"Video: {video_path}")
print(f"Audio: {audio_path}")
print(f"Output: {video_out_path}")
print("="*80 + "\n")
    # Extract audio features
    print("\n[1/5] Extracting audio features...")
    audio_start = time.time()
    audio_samples = read_audio(audio_path)
    audio_time = time.time() - audio_start
    print(f"Audio feature extraction completed in {audio_time:.2f} seconds")

    # Extract faces from the video
    print("\n[2/5] Processing video frames...")
    video_start = time.time()
    faces, debug_faces, original_video_frames, boxes, affine_matrices = self.affine_transform_video(
        video_path, video_skip_frames=video_skip_frames, video_num_workers=video_num_workers
    )
    video_time = time.time() - video_start
    print(f"Video processing completed in {video_time:.2f} seconds")
    # Get latents
    print("\n[3/5] Encoding video frames to latent space...")
    latent_start = time.time()
    all_latents = self.prepare_latents(
        1,
        len(faces),
        self.vae.config.latent_channels,
        height or self.unet.config.sample_size * self.vae_scale_factor,
        width or self.unet.config.sample_size * self.vae_scale_factor,
        weight_dtype,
        self._execution_device,
        generator,
    )
    latent_time = time.time() - latent_start
    print(f"Latent encoding completed in {latent_time:.2f} seconds")

    # Prepare scheduler
    self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
    timesteps = self.scheduler.timesteps

    # Scale the initial noise latents by the scheduler's init_noise_sigma
    all_latents = all_latents * self.scheduler.init_noise_sigma
    # Prepare audio features
    if self.unet.add_audio_layer:
        whisper_feature = self.audio_encoder.audio2feat(audio_path)
        whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
        total_frames = min(len(faces), len(whisper_chunks))
    else:
        total_frames = len(faces)

    num_inferences = math.ceil(total_frames / num_frames)

    synced_video_frames = []
    masked_video_frames = []
    for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
        # Calculate the actual number of frames for this chunk
        start_idx = i * num_frames
        end_idx = min(start_idx + num_frames, total_frames)
        chunk_size = end_idx - start_idx

        if self.unet.add_audio_layer:
            audio_embeds = torch.stack(whisper_chunks[start_idx:end_idx])
            audio_embeds = audio_embeds.to(device=self._execution_device, dtype=weight_dtype)
            if guidance_scale > 1.0:
                # For classifier-free guidance, prepend null (zero) audio embeddings
                null_audio_embeds = torch.zeros_like(audio_embeds)
                audio_embeds = torch.cat([null_audio_embeds, audio_embeds])
        else:
            audio_embeds = None

        inference_faces = faces[start_idx:end_idx]
        latents = all_latents[:, :, start_idx:end_idx]

        pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
            inference_faces, affine_transform=False
        )

        # 7. Prepare mask latent variables
        mask_latents, masked_image_latents = self.prepare_mask_latents(
            masks,
            masked_pixel_values,
            height or self.unet.config.sample_size * self.vae_scale_factor,
            width or self.unet.config.sample_size * self.vae_scale_factor,
            weight_dtype,
            self._execution_device,
            generator,
            guidance_scale > 1.0,
        )

        # 8. Prepare image latents
        image_latents = self.prepare_image_latents(
            pixel_values,
            self._execution_device,
            weight_dtype,
            generator,
            guidance_scale > 1.0,
        )
        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        if isinstance(self.scheduler, (DDIMScheduler, PNDMScheduler)):
            # Use more steps in the beginning of the sequence
            base_steps = num_inference_steps
            warmup_steps = min(base_steps + 5, 35)  # Add 5 more steps for the first chunk, max 35
            if i == 0:
                self.scheduler.set_timesteps(warmup_steps, device=self._execution_device)
            else:
                self.scheduler.set_timesteps(base_steps, device=self._execution_device)
            timesteps = self.scheduler.timesteps

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for j, t in enumerate(timesteps):
                # Expand the latents if we are doing classifier-free guidance
                latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents

                # Concat latents, mask, masked_image_latents in the channel dimension
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = torch.cat(
                    [latent_model_input, mask_latents, masked_image_latents, image_latents], dim=1
                )

                # Predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=audio_embeds).sample

                # Perform guidance
                if guidance_scale > 1.0:
                    noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **self.prepare_extra_step_kwargs(generator, eta)
                ).prev_sample

                # Call the callback, if provided
                if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and j % callback_steps == 0:
                        callback(j, t, latents)
        # Recover the pixel values
        decoded_latents = self.decode_latents(latents)
        decoded_latents = self.paste_surrounding_pixels_back(
            decoded_latents, pixel_values, 1 - masks, self._execution_device, weight_dtype
        )
        synced_video_frames.append(decoded_latents)

    # Restore the generated faces into the original video frames
    synced_video_frames = self.restore_video(
        torch.cat(synced_video_frames), original_video_frames, boxes, affine_matrices
    )

    # Trim the audio to match the length of the generated video
    audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
    audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()

    # Write intermediate video and audio to a temp dir, then mux them with ffmpeg
    temp_dir = "temp"
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir, exist_ok=True)

    write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=25)  # assumes video_fps == 25 (the default above)
    sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)

    command = (
        f"ffmpeg -threads 32 -y -loglevel error -nostdin "
        f"-i {os.path.join(temp_dir, 'video.mp4')} -i {os.path.join(temp_dir, 'audio.wav')} "
        f"-c:v libx264 -c:a aac -q:v 0 -q:a 0 {video_out_path}"
    )
    subprocess.run(command, shell=True)

    # Print overall statistics
    total_time = time.time() - overall_start_time
    print("\n" + "=" * 80)
    print(f"LatentSync pipeline completed in {total_time:.2f} seconds")
    print("=" * 80)