Skip to content

Instantly share code, notes, and snippets.

@eustlb
Created July 23, 2024 16:30
Show Gist options
  • Select an option

  • Save eustlb/ae524f072dbe9cb6c1c3a6ff2486bc9c to your computer and use it in GitHub Desktop.

Select an option

Save eustlb/ae524f072dbe9cb6c1c3a6ff2486bc9c to your computer and use it in GitHub Desktop.
Reproduces a bug that occurs with torch 2.3.1 when using torch.compile: the same Parler-TTS generation is run once eagerly and once compiled, so the two output files can be compared.
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch
# Fixed seed plus the device/model configuration shared by both repro sections.
torch.manual_seed(0)
CUDA_DEVICE = 0
torch_device = f"cuda:{CUDA_DEVICE}"
attn_implementation = "eager"
model_name = "parler-tts/parler_tts_mini_v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Inputs: a voice/style description (conditioning) and the text to be spoken.
prompt = "Hey, how are you doing today?"
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."

# Tokenize each input and move the id tensors straight to the target device.
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(torch_device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)
## 1
# Run 1: plain (uncompiled) fp16 generation — the reference output.
torch_dtype = torch.float16
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation,
).to(torch_device, dtype=torch_dtype)

generation = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
)
# Cast to fp32 before the numpy round-trip, then squeeze away singleton dims.
audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
sf.write("./out_1.wav", audio_arr, model.config.sampling_rate)
## 2
# Run 2: same model, but with a static KV cache and a torch.compile'd forward.
# Re-seed so run 2 draws from the same RNG stream as run 1 — otherwise any
# difference between out_1.wav and out_2.wav could be sampling noise rather
# than a compile-induced bug.
torch.manual_seed(0)

# NOTE(review): the original re-tokenized the identical prompt/description
# here; `input_ids` and `prompt_input_ids` from section 1 hold the same
# tensors, so they are reused instead of being rebuilt.
torch_dtype = torch.float16
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation,
).to(torch_device, dtype=torch_dtype)

# A static cache keeps tensor shapes fixed, which fullgraph compilation needs.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="default", fullgraph=True)

generation = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
).to(torch.float32)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("./out_2.wav", audio_arr, model.config.sampling_rate)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment