import os

import torch
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

# caching allows ~50% compilation time reduction
# see https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.o2asbxsrp1ma
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.join(CURRENT_DIR, "tmp")
torch._inductor.config.fx_graph_cache = True

torch._logging.set_logs(graph_breaks=True, recompiles=True)

# mind this parameter! It should be >= 2 * the number of compiled models.
torch._dynamo.config.cache_size_limit = 15
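# rough budget (an estimate, assuming ~2 graphs per warmed-up pad length,
# one for prefill and one for decode): the 3 pad lengths used below need
# ~6 cache entries, so a limit of 15 leaves comfortable headroom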
def prepare_model_inputs(
    description,
    prompt,
    description_tokenizer,
    prompt_tokenizer,
    device,
    max_length_description=30,
    max_length_prompt=50,
    pad=False,
):
    pad_args_description = {"padding": "max_length", "max_length": max_length_description} if pad else {}
    pad_args_prompt = {"padding": "max_length", "max_length": max_length_prompt} if pad else {}

    tokenized_description = description_tokenizer(description, return_tensors="pt", **pad_args_description)
    input_ids = tokenized_description.input_ids.to(device)
    attention_mask = tokenized_description.attention_mask.to(device)

    tokenized_prompt = prompt_tokenizer(prompt, return_tensors="pt", **pad_args_prompt)
    prompt_input_ids = tokenized_prompt.input_ids.to(device)
    prompt_attention_mask = tokenized_prompt.attention_mask.to(device)

    if pad:
        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
        }
    else:
        model_kwargs = {
            "input_ids": input_ids,
            "prompt_input_ids": prompt_input_ids,
        }

    return model_kwargs


def next_power_of_2(x):
    return 1 if x == 0 else 2**(x - 1).bit_length()
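# e.g. next_power_of_2(0) -> 1, next_power_of_2(5) -> 8, next_power_of_2(16) -> 16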
def load_prompts(filename):
    with open(filename, 'r') as file:
        phrases = file.read().strip().split('\n\n')
    return phrases
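# prompts.txt is expected to hold one prompt per paragraph, separated by blank lines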
model_name = "ylacombe/parler-tts-mini-jenny-30H"
torch_device = "cuda:0"
torch_dtype = torch.bfloat16
attn_implementation = "eager"
compile_mode = "default"

# load prompts
prompt_list = load_prompts('./prompts.txt')

# load model
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)

model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode=compile_mode, fullgraph=True)
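# the static KV cache above keeps tensor shapes fixed across decoding steps,
# so torch.compile(fullgraph=True) can trace one graph per padded input length
# instead of recompiling on every step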
# load tokenizers
padding_side = "left"
description_tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)
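# left-padding keeps the real prompt tokens adjacent to the generated tokens,
# the usual convention when padding decoder prompts for generation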
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."

pad_lengths = [16, 32, 128]
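# every distinct padded prompt length compiles its own graph, so prompts are
# bucketed into these few fixed lengths and each bucket is warmed up once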
# warmup pass
# warm up in reversed order (longest first) to avoid recompilation
for pad_length in pad_lengths[::-1]:
    model_kwargs = prepare_model_inputs(
        description,
        prompt_list[0],
        description_tokenizer,
        prompt_tokenizer,
        torch_device,
        max_length_prompt=pad_length,
        pad=True
    )

    # 2 warmup steps for modes that capture CUDA graphs
    n_steps = 1 if compile_mode == "default" else 2

    print(f"Warming up length {pad_length} tokens...")
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    start_event.record()
    for _ in range(n_steps):
        _ = model.generate(**model_kwargs)
    end_event.record()
    torch.cuda.synchronize()

    print(f"Warmed up! Compilation time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s")
# generation
for i, prompt in enumerate(prompt_list):
    nb_tokens = len(prompt_tokenizer(prompt).input_ids)
    # pad to the next power of two
    pad_length = next_power_of_2(nb_tokens)
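    # note: a prompt that buckets to a length not warmed up above (e.g. 64
    # tokens) would trigger a fresh compilation on its first use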
    model_kwargs = prepare_model_inputs(
        description,
        prompt,
        description_tokenizer,
        prompt_tokenizer,
        torch_device,
        max_length_prompt=pad_length,
        pad=True
    )

    torch.manual_seed(0)
    print(f"generating for length {pad_length}")
    generation = model.generate(**model_kwargs)

    audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
    sf.write(f"./output_{i}.wav", audio_arr, model.config.sampling_rate)
prompts.txt
Hey, how are you doing today?

Hey, how are you doing today? I hope you're having a fantastic day and that everything is going well for you so far.

Hey, how are you doing today? I hope you're having a fantastic day and that everything is going well for you so far. If you have a moment, I'd love to catch up and hear about what you've been up to lately, whether it's something exciting, a new project you're working on, or just how life has been treating you overall.