@eustlb
Last active September 5, 2024 16:08
import os
import torch
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
# caching allows ~50% compilation time reduction
# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
torch._inductor.config.fx_graph_cache = True
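# fx_graph_cache persists Inductor's compiled artifacts under TORCHINDUCTOR_CACHE_DIR,
# so subsequent runs of this script can reuse them instead of recompiling from scratch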
torch._logging.set_logs(graph_breaks=True, recompiles=True)
# Mind this parameter! It should be >= 2 * the number of compiled models.
torch._dynamo.config.cache_size_limit = 15
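# with the 3 prompt pad lengths used below (16, 32, 128), several distinct graphs get
# compiled; 15 leaves headroom before dynamo falls back to eager on new shapes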
def prepare_model_inputs(
    description,
    prompt,
    description_tokenizer,
    prompt_tokenizer,
    device,
    max_length_description=30,
    max_length_prompt=50,
    pad=False,
):
    pad_args_description = {"padding": "max_length", "max_length": max_length_description} if pad else {}
    pad_args_prompt = {"padding": "max_length", "max_length": max_length_prompt} if pad else {}

    tokenized_description = description_tokenizer(description, return_tensors="pt", **pad_args_description)
    input_ids = tokenized_description.input_ids.to(device)
    attention_mask = tokenized_description.attention_mask.to(device)

    tokenized_prompt = prompt_tokenizer(prompt, return_tensors="pt", **pad_args_prompt)
    prompt_input_ids = tokenized_prompt.input_ids.to(device)
    prompt_attention_mask = tokenized_prompt.attention_mask.to(device)

    if pad:
        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
        }
    else:
        model_kwargs = {
            "input_ids": input_ids,
            "prompt_input_ids": prompt_input_ids,
        }

    return model_kwargs
def next_power_of_2(x):
    return 1 if x == 0 else 2 ** (x - 1).bit_length()
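# e.g. next_power_of_2(13) == 16 and next_power_of_2(32) == 32: prompt lengths are
# bucketed into a few padded shapes so the compiled graphs can be reused across prompts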
def load_prompts(filename):
    with open(filename, 'r') as file:
        phrases = file.read().strip().split('\n\n')
    return phrases
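# prompts.txt is expected to hold one prompt per paragraph, i.e. prompts separated by
# blank lines (see the example prompts.txt content after the script)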
model_name = "ylacombe/parler-tts-mini-jenny-30H"
torch_device = "cuda:0"
torch_dtype = torch.bfloat16
attn_implementation = "eager"
compile_mode = "default"
# load prompts
prompt_list = load_prompts('./prompts.txt')
# load model
model = ParlerTTSForConditionalGeneration.from_pretrained(
model_name,
attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)
model.generation_config.cache_implementation = "static"
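# a static KV cache keeps tensor shapes constant across decoding steps, which is what
# allows torch.compile with fullgraph=True to avoid shape-driven recompilations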
model.forward = torch.compile(model.forward, mode=compile_mode, fullgraph=True)
# load tokenizers
padding_side = "left"
description_tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
pad_lengths = [16, 32, 128]
# warmup pass
# warm up in reversed order to avoid recompilation
for pad_length in pad_lengths[::-1]:
    model_kwargs = prepare_model_inputs(
        description,
        prompt_list[0],
        description_tokenizer,
        prompt_tokenizer,
        torch_device,
        max_length_prompt=pad_length,
        pad=True
    )

    # 2 warmup steps for modes that capture CUDA graphs
    n_steps = 1 if compile_mode == "default" else 2

    print(f"Warming up length {pad_length} tokens...")
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    for _ in range(n_steps):
        _ = model.generate(**model_kwargs)
    end_event.record()
    torch.cuda.synchronize()
    print(f"Warmed up! Compilation time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
# generation
for i, prompt in enumerate(prompt_list):
    nb_tokens = len(prompt_tokenizer(prompt).input_ids)

    # pad to the closest upper power of two
    pad_length = next_power_of_2(nb_tokens)

    model_kwargs = prepare_model_inputs(
        description,
        prompt,
        description_tokenizer,
        prompt_tokenizer,
        torch_device,
        max_length_prompt=pad_length,
        pad=True
    )

    torch.manual_seed(0)
    print(f"Generating for length {pad_length}...")
    generation = model.generate(**model_kwargs)
    audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
    sf.write(f"./output_{i}.wav", audio_arr, model.config.sampling_rate)
prompts.txt:

Hey, how are you doing today?

Hey, how are you doing today? I hope you're having a fantastic day and that everything is going well for you so far.

Hey, how are you doing today? I hope you're having a fantastic day and that everything is going well for you so far. If you have a moment, I'd love to catch up and hear about what you've been up to lately, whether it's something exciting, a new project you're working on, or just how life has been treating you overall.
eustlb commented Jul 26, 2024

conda create --yes -n parler-tts-static python=3.11
conda activate parler-tts-static
pip install git+https://github.com/eustlb/parler-tts.git@add-static-cache
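
Then run the script (assuming it is saved as generate_static.py next to prompts.txt):

python generate_static.py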
