Created
August 6, 2024 21:33
-
-
Save nucklearproject/2f170ef8400ad16bd1f84ff2b2c37cb0 to your computer and use it in GitHub Desktop.
Gradio UI for inference and batch text-file processing with fine-tuned XTTS (TTS) models.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tempfile
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# Currently loaded XTTS model shared by the UI callbacks; stays ``None`` until
# ``load_model`` has been run successfully.
XTTS_MODEL = None
def clear_gpu_cache():
    """Empty PyTorch's CUDA cache when a CUDA device is available.

    Does nothing on machines without CUDA.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def list_wav_files(folder_path):
    """Return the paths of the ``.wav`` files directly inside *folder_path*.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        A sorted list of ``os.path.join(folder_path, name)`` strings. The
        extension match is case-insensitive. An empty list is returned when
        *folder_path* is empty, missing, or not a directory, so that building
        the UI does not crash on a machine without the voices folder.
    """
    if not folder_path:
        return []
    try:
        entries = os.listdir(folder_path)
    except (FileNotFoundError, NotADirectoryError):
        return []
    # Sorted so the dropdown order is stable across runs and platforms.
    return sorted(
        os.path.join(folder_path, name)
        for name in entries
        if name.lower().endswith(".wav")
    )
def load_model(xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker):
    """Load a fine-tuned XTTS model into the module-level ``XTTS_MODEL``.

    Args:
        xtts_checkpoint: Path to the model checkpoint file.
        xtts_config: Path to the model's JSON config.
        xtts_vocab: Path to the tokenizer vocabulary file.
        xtts_speaker: Path to the speaker file, forwarded as
            ``speaker_file_path``.

    Returns:
        A status string for the UI: ``"Model Loaded!"`` on success, or a
        message asking for the missing paths when the checkpoint, config or
        vocab path is empty.

    Note:
        Any previously loaded model is dropped before the new one is loaded.
        If loading raises, ``XTTS_MODEL`` is left as ``None`` rather than
        pointing at a partially initialised model.
    """
    global XTTS_MODEL
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"

    # Drop the old model *before* emptying the cache; while it is still
    # referenced its GPU memory cannot be returned.
    XTTS_MODEL = None
    clear_gpu_cache()

    config = XttsConfig()
    config.load_json(xtts_config)
    # Build into a local first so the global is only published once the
    # checkpoint has loaded successfully.
    model = Xtts.init_from_config(config)
    print("Loading XTTS model!")
    model.load_checkpoint(
        config,
        checkpoint_path=xtts_checkpoint,
        vocab_path=xtts_vocab,
        speaker_file_path=xtts_speaker,
        use_deepspeed=False,
    )
    if torch.cuda.is_available():
        model.cuda()

    XTTS_MODEL = model
    print("Model Loaded!")
    return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file, temperature, length_penalty, repetition_penalty, top_k, top_p, sentence_split, use_config):
    """Synthesize *tts_text* with the loaded XTTS model and save it as a WAV.

    Args:
        lang: Language code passed to the model.
        tts_text: Text to synthesize.
        speaker_audio_file: Path of the reference audio used to compute the
            speaker conditioning.
        temperature, length_penalty, repetition_penalty, top_k, top_p:
            Sampling settings from the UI; used only when *use_config* is
            false.
        sentence_split: Whether to enable text splitting; used only when
            *use_config* is false.
        use_config: When true, take the sampling settings from the loaded
            model's config (with text splitting enabled) instead of the UI
            values.

    Returns:
        A ``(status, output_wav_path, speaker_audio_file)`` tuple. When the
        model is not loaded, no speaker audio is given, or the text is empty,
        the status explains the problem and the other two items are ``None``.
    """
    if XTTS_MODEL is None:
        return "You need to run the previous step to load the model !!", None, None
    if not speaker_audio_file:
        return "You need to select a speaker reference audio !!", None, None
    if not tts_text or not tts_text.strip():
        return "You need to enter some text to synthesize !!", None, None

    config = XTTS_MODEL.config
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=config.gpt_cond_len,
        max_ref_length=config.max_ref_len,
        sound_norm_refs=config.sound_norm_refs,
    )

    if use_config:
        # Read the values from the model config, as the checkbox label
        # promises and as process_batch already does.
        sampling = dict(
            temperature=config.temperature,
            length_penalty=config.length_penalty,
            repetition_penalty=config.repetition_penalty,
            top_k=config.top_k,
            top_p=config.top_p,
            enable_text_splitting=True,
        )
    else:
        sampling = dict(
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=float(repetition_penalty),
            top_k=top_k,
            top_p=top_p,
            enable_text_splitting=sentence_split,
        )

    out = XTTS_MODEL.inference(
        text=tts_text,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        **sampling,
    )

    # The temp file is created inside ./output, which must exist first.
    os.makedirs("output", exist_ok=True)
    # Only reserve a unique name here; the handle is closed before torchaudio
    # writes to that path.
    with tempfile.NamedTemporaryFile(dir="output", suffix=".wav", delete=False) as fp:
        out_path = fp.name
    wav = torch.tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(out_path, wav, 24000)
    return "Speech generated!", out_path, speaker_audio_file
def list_models_path(folder_path="d:/TTS/xtts-webui/finetuned_models/"):
    """Return the sub-directories of the fine-tuned models folder.

    Args:
        folder_path: Directory whose immediate sub-directories are listed.
            Defaults to the location this script was written for.

    Returns:
        A list of ``os.path.join(folder_path, name)`` strings, one per
        sub-directory, sorted by name. An empty list is returned when
        *folder_path* is missing or not a directory, so the UI can still be
        built.
    """
    try:
        entries = os.listdir(folder_path)
    except (FileNotFoundError, NotADirectoryError):
        return []
    candidates = (os.path.join(folder_path, name) for name in sorted(entries))
    return [path for path in candidates if os.path.isdir(path)]
# Folder scanned for ``.wav`` files offered as speaker reference audio.
path_reference_wavs = "d:/proyectos_azure_tts_git/coqui-tts-src/voices/"
def load_params_tts(out_path):
    """Locate the files of a fine-tuned model under ``<out_path>/ready``.

    Args:
        out_path: Folder of a fine-tuned model; its ``ready`` sub-folder is
            searched.

    Returns:
        A 6-tuple ``(status, model_path, config_path, vocab_path,
        speaker_path, reference_path)``. ``model.pth`` is preferred and
        ``unoptimize_model.pth`` is the fallback. When *out_path* is empty or
        neither model file exists, the status is
        ``"Params for TTS not found"`` and the five paths are empty strings.
        Only the model file's existence is checked.
    """
    not_found = ("Params for TTS not found", "", "", "", "", "")
    # An empty selection would otherwise be read as the current directory
    # (or raise for None).
    if not out_path:
        return not_found

    ready_model_path = Path(out_path) / "ready"
    model_path = next(
        (
            candidate
            for candidate in (
                ready_model_path / "model.pth",
                ready_model_path / "unoptimize_model.pth",
            )
            if candidate.exists()
        ),
        None,
    )
    if model_path is None:
        return not_found

    return (
        "Params for TTS loaded",
        model_path,
        ready_model_path / "config.json",
        ready_model_path / "vocab.json",
        ready_model_path / "speakers_xtts.pth",
        ready_model_path / "reference.wav",
    )
def process_batch(folder_path, tts_language, speaker_audio_file):
    """Synthesize every ``.txt`` file in *folder_path* to a WAV file.

    Each ``<name>.txt`` is read as UTF-8 and written to
    ``<folder_path>/batchaudio/<name>.wav`` using the sampling settings from
    the loaded model's config, with text splitting enabled.

    Args:
        folder_path: Folder containing the text files.
        tts_language: Language code passed to the model.
        speaker_audio_file: Path of the reference audio used to compute the
            speaker conditioning.

    Returns:
        A status string for the UI.
    """
    if XTTS_MODEL is None:
        return "Model not loaded. Please load the model first."
    if not speaker_audio_file:
        return "Speaker reference audio not selected. Please choose one first."
    if not folder_path or not os.path.isdir(folder_path):
        return "Folder not found. Please enter a valid folder path."

    source_dir = Path(folder_path)
    text_files = sorted(
        name for name in os.listdir(source_dir) if name.lower().endswith(".txt")
    )
    if not text_files:
        return "No .txt files found in the folder."

    output_folder = source_dir / "batchaudio"
    output_folder.mkdir(exist_ok=True)

    config = XTTS_MODEL.config
    # The speaker conditioning depends only on the reference audio, so it is
    # computed once instead of once per text file.
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=config.gpt_cond_len,
        max_ref_length=config.max_ref_len,
        sound_norm_refs=config.sound_norm_refs,
    )

    skipped = 0
    for text_file in text_files:
        with open(source_dir / text_file, 'r', encoding='utf-8') as file:
            tts_text = file.read()
        if not tts_text.strip():
            # Nothing to synthesize; skip rather than hand the model empty text.
            skipped += 1
            continue

        out = XTTS_MODEL.inference(
            text=tts_text,
            language=tts_language,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=config.temperature,
            length_penalty=config.length_penalty,
            repetition_penalty=config.repetition_penalty,
            top_k=config.top_k,
            top_p=config.top_p,
            enable_text_splitting=True,
        )
        output_path = output_folder / f"{Path(text_file).stem}.wav"
        wav = torch.tensor(out["wav"]).unsqueeze(0)
        torchaudio.save(str(output_path), wav, 24000)

    if skipped:
        return f"Batch processing completed! Skipped {skipped} empty file(s)."
    return "Batch processing completed!"
# Build the Gradio UI: one tab for single-text inference and one tab for batch
# processing of a folder of .txt files, then start the local server.
# NOTE(review): this block was recovered from a copy that had lost its
# indentation, so which widgets sit inside which column/tab is a
# reconstruction -- confirm against the intended layout.
with gr.Blocks() as demo:
    with gr.Tab("3 - Inference"):
        with gr.Row():
            with gr.Column() as col1:
                # allow_custom_value: the initial "" is not one of the choices.
                # NOTE(review): newer Gradio releases are understood to reject
                # such values unless this flag is set -- confirm for the
                # installed version.
                out_path = gr.Dropdown(
                    label="Output path (where data and checkpoints will be saved):",
                    choices=list_models_path(),
                    value="",
                    allow_custom_value=True,
                )
                load_params_tts_btn = gr.Button(value="Load params for TTS from output folder")
                xtts_checkpoint = gr.Textbox(
                    label="XTTS checkpoint path:",
                    value="",
                )
                xtts_config = gr.Textbox(
                    label="XTTS config path:",
                    value="",
                )
                xtts_vocab = gr.Textbox(
                    label="XTTS vocab path:",
                    value="",
                )
                xtts_speaker = gr.Textbox(
                    label="XTTS speaker path:",
                    value="",
                )
                progress_load = gr.Label(
                    label="Progress:"
                )
                load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")

            with gr.Column() as col2:
                # load_params_tts writes the model's reference.wav path into
                # this dropdown, which is normally not among the scanned
                # choices, so custom values must be accepted.
                speaker_reference_audio = gr.Dropdown(
                    label="Speaker reference audio:",
                    value="",
                    choices=list_wav_files(path_reference_wavs),
                    allow_custom_value=True,
                )
                tts_language = gr.Dropdown(
                    label="Language",
                    value="en",
                    choices=["en", "es", "fr", "de", "it", "pt"]
                )
                tts_text = gr.Textbox(
                    label="Input Text.",
                    lines=10,
                    value="",
                )
                with gr.Accordion("Advanced settings", open=False) as acr:
                    temperature = gr.Slider(
                        label="temperature",
                        minimum=0,
                        maximum=1,
                        step=0.05,
                        value=0.75,
                    )
                    length_penalty = gr.Slider(
                        label="length_penalty",
                        minimum=-10.0,
                        maximum=10.0,
                        step=0.5,
                        value=1,
                    )
                    repetition_penalty = gr.Slider(
                        label="repetition penalty",
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=5,
                    )
                    top_k = gr.Slider(
                        label="top_k",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                    )
                    top_p = gr.Slider(
                        label="top_p",
                        minimum=0,
                        maximum=1,
                        step=0.05,
                        value=0.85,
                    )
                    sentence_split = gr.Checkbox(
                        label="Enable text splitting",
                        value=True,
                    )
                    use_config = gr.Checkbox(
                        label="Use Inference settings from config, if disabled use the settings above",
                        value=False,
                    )
                tts_btn = gr.Button(value="Step 4 - Inference")
                progress_gen = gr.Label(
                    label="Progress:"
                )
                tts_output_audio = gr.Audio(label="Generated Audio.")
                reference_audio = gr.Audio(label="Reference audio used.")

    with gr.Tab("4 - Batch Processing"):
        batch_folder = gr.Textbox(
            label="Folder containing text files to process:",
            placeholder="Enter the path to the folder with text files"
        )
        speaker_reference_audio_batch = gr.Dropdown(
            label="Speaker reference audio:",
            value="",
            choices=list_wav_files(path_reference_wavs),
            allow_custom_value=True,
        )
        tts_language_batch = gr.Dropdown(
            label="Language",
            value="en",
            choices=["en", "es", "fr", "de", "it", "pt"]
        )
        batch_btn = gr.Button(value="Start Batch Processing")
        progress_batch = gr.Label(
            label="Batch processing status:"
        )

    # Event listeners are registered while the Blocks context is still open.
    load_params_tts_btn.click(
        fn=load_params_tts,
        inputs=[out_path],
        outputs=[progress_load, xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker, speaker_reference_audio],
    )
    load_btn.click(
        fn=load_model,
        inputs=[xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker],
        outputs=[progress_load],
    )
    tts_btn.click(
        fn=run_tts,
        inputs=[tts_language, tts_text, speaker_reference_audio, temperature, length_penalty, repetition_penalty, top_k, top_p, sentence_split, use_config],
        outputs=[progress_gen, tts_output_audio, reference_audio],
    )
    batch_btn.click(
        fn=process_batch,
        inputs=[batch_folder, tts_language_batch, speaker_reference_audio_batch],
        outputs=[progress_batch],
    )

demo.launch(
    share=False,
    debug=False,
    server_port=5003,
    server_name="localhost"
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment