Created
July 23, 2024 16:30
-
-
Save eustlb/ae524f072dbe9cb6c1c3a6ff2486bc9c to your computer and use it in GitHub Desktop.
Reproduces a bug that occurs with torch 2.3.1 when the model's forward pass is wrapped in torch.compile.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
# Seed PyTorch's random number generator once, before anything else runs.
torch.manual_seed(0)

# --- Configuration shared by both runs below --------------------------------
CUDA_DEVICE = 0
torch_device = f"cuda:{CUDA_DEVICE}"
attn_implementation = "eager"
model_name = "parler-tts/parler_tts_mini_v0.1"

# --- Text inputs -------------------------------------------------------------
# `prompt` and `description` are tokenized separately and handed to
# `generate()` as `prompt_input_ids` and `input_ids` respectively.
prompt = "Hey, how are you doing today?"
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize on the CPU first, then move only the id tensors to the target device.
tokenized_description = tokenizer(description, return_tensors="pt")
tokenized_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = tokenized_description.input_ids.to(torch_device)
prompt_input_ids = tokenized_prompt.input_ids.to(torch_device)
## 1
# Baseline run: the model is loaded, moved to the GPU in float16, and used
# as-is (no torch.compile, default cache implementation).
torch_dtype = torch.float16
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name, attn_implementation=attn_implementation
)
model = model.to(torch_device, dtype=torch_dtype)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
# Cast the result to float32 before it is converted to a NumPy array and
# written to disk.
generation = generation.to(torch.float32)

# Move to host memory and drop singleton dimensions before writing the file.
audio_arr = generation.cpu().numpy().squeeze()
sf.write("./out_1.wav", audio_arr, model.config.sampling_rate)
## 2
# Compiled run: identical to run 1 except that the generation config is
# switched to the "static" cache implementation and `model.forward` is wrapped
# in torch.compile. Comparing ./out_2.wav with ./out_1.wav exposes the bug.
#
# The prompt/description tensors (`input_ids`, `prompt_input_ids`) prepared at
# the top of the script are reused here; re-tokenizing the same two strings
# would produce the same tensors.

# Re-seed so this run starts from the same RNG state run 1 started from (the
# script seeds once at the top, then loads the model, then generates). Without
# this, any random sampling inside generate() would continue from wherever
# run 1 left the generator, and the two outputs could differ for reasons
# unrelated to torch.compile.
torch.manual_seed(0)

torch_dtype = torch.float16
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation,
).to(torch_device, dtype=torch_dtype)

# The two settings that distinguish this run from the baseline.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="default", fullgraph=True)

generation = model.generate(
    input_ids=input_ids,
    prompt_input_ids=prompt_input_ids,
).to(torch.float32)

# Move to host memory and drop singleton dimensions before writing the file.
audio_arr = generation.cpu().numpy().squeeze()
sf.write("./out_2.wav", audio_arr, model.config.sampling_rate)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment