Routine to generate an ONNX model from an ESPnet2 Text2Speech model
#!/usr/bin/env python3
"""Convert an ESPnet2 Text2Speech model to ONNX.

Test command:
    python convert_tts2onnx.py --tts-tag espnet/kan-bayashi_ljspeech_vits
"""
import argparse
import logging
import sys
import time
from typing import Optional

import numpy as np
import torch

from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none


def get_parser():
    parser = argparse.ArgumentParser(
        description="Convert an ESPnet2 Text2Speech model to ONNX",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--tts-tag",
        required=True,
        type=str,
        help="TTS tag (or directory) for a model located at huggingface/zenodo/local",
    )
    return parser


### Add this method to the VITS class in espnet2/gan_tts/vits/vits.py
### (shown here for reference; the script below calls model_tts.inference_onnx)
def inference_onnx(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    durations: Optional[torch.Tensor] = None,
    noise_scale: float = 0.667,
    noise_scale_dur: float = 0.8,
    alpha: float = 1.0,
    max_len: Optional[int] = None,
    use_teacher_forcing: bool = False,
) -> torch.Tensor:
    """Run VITS inference for ONNX export; returns the flattened waveform."""
    if sids is not None:
        sids = sids.view(1)
    if lids is not None:
        lids = lids.view(1)
    if durations is not None:
        durations = durations.view(1, 1, -1)

    # Inference (teacher forcing is not supported in this export path)
    if use_teacher_forcing:
        raise NotImplementedError
    wav, _, _ = self.generator.inference(
        text=text,
        text_lengths=text_lengths,
        sids=sids,
        spembs=spembs,
        lids=lids,
        dur=durations,
        noise_scale=noise_scale,
        noise_scale_dur=noise_scale_dur,
        alpha=alpha,
        max_len=max_len,
    )
    return wav.view(-1)


def test_onnx():
    """Smoke-test the exported model with onnxruntime."""
    logging.info('Test ONNX')
    import onnxruntime as ort

    # `preprocessing` is the Text2Speech preprocess_fn set in __main__
    this_text = 'Hello world, how are you doing'
    this_text = preprocessing("<dummy>", dict(text=this_text))['text']
    this_text = this_text[None]  # add a batch dimension
    # this_len = np.array([this_text.shape[1]], dtype=int)
    ort_sess = ort.InferenceSession('tts_model.onnx')
    inname = [inp.name for inp in ort_sess.get_inputs()]
    outname = [out.name for out in ort_sess.get_outputs()]
    logging.info("inputs name: %s || outputs name: %s", inname, outname)
    outputs = ort_sess.run(None, {'input_text': this_text})
    logging.info(type(outputs))


if __name__ == "__main__":
    # Logger
    parser = get_parser()
    args = parser.parse_args()
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(filename='onnx.log', encoding='utf-8', level=logging.INFO, format=logfmt)

    # Load the pretrained model and test wav generation
    logging.info("Preparing pretrained model from: %s", args.tts_tag)
    text2speech = Text2Speech.from_pretrained(
        model_tag=str_or_none(args.tts_tag),
        vocoder_tag=None,
        device="cuda",
        # Only for Tacotron 2 & Transformer
        threshold=0.5,
        # Only for Tacotron 2
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=False,
        backward_window=1,
        forward_window=3,
        # Only for FastSpeech, FastSpeech2 & VITS
        speed_control_alpha=1.0,
        # Only for VITS
        noise_scale=0.667,
        noise_scale_dur=0.8,
    )
    text = 'Hello world'
    logging.info("Generating test wav using the sequence: %s", text)
    with torch.no_grad():
        start = time.time()
        wav = text2speech(text)["wav"]
        rtf = (time.time() - start) / (len(wav) / text2speech.fs)
    logging.info(f"RTF = {rtf:.5f}")
    # Prepare modules for conversion
    logging.info("Generate ONNX models")
    with torch.no_grad():
        device = text2speech.device
        preprocessing = text2speech.preprocess_fn
        model_tts = text2speech.tts
        # Replace forward with inference_onnx to avoid problems during export
        model_tts.forward = model_tts.inference_onnx

        # Preprocess the input text
        preproc_text = preprocessing("<dummy>", dict(text=text))['text']
        preproc_text = torch.from_numpy(preproc_text).to(device).unsqueeze(0)
        text_lengths = torch.tensor(
            [preproc_text.size(1)],
            dtype=torch.long,
            device=preproc_text.device,
        )
        # Sanity check: run the patched forward once before exporting
        wav = model_tts(preproc_text, text_lengths)
        logging.info(wav.shape)
        inputs = (preproc_text, text_lengths)
        # Export the TTS model.
        # NOTE: only input_text is named and marked dynamic; text_lengths is
        # traced with the dummy sentence, which can bake its length into the
        # graph (see the Where broadcast error discussed in the comments).
        torch.onnx.export(
            model_tts,
            inputs,
            'tts_model.onnx',
            export_params=True,
            opset_version=13,
            do_constant_folding=True,
            verbose=True,
            input_names=['input_text'],
            output_names=['wav'],
            dynamic_axes={
                'input_text': {1: 'length'},
                'wav': {0: 'length'},
            },
        )
    test_onnx()
    sys.exit(0)
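
Since the export marks only the token axis of input_text as dynamic, it is worth verifying that the exported graph really accepts lengths other than that of the traced sentence. A minimal sketch of such a check (it assumes tts_model.onnx and the preprocessing function from the script above; check_dynamic_length is a hypothetical helper name):

import logging

import onnxruntime as ort


def check_dynamic_length(preprocessing, model_path='tts_model.onnx'):
    """Run the exported model on two sentences of different lengths.

    If a length was baked in at export time, the second run typically
    fails with a broadcast error (see the comments below).
    """
    sess = ort.InferenceSession(model_path)
    for sentence in ('Hello world', 'Hello world, how are you doing today'):
        tokens = preprocessing("<dummy>", dict(text=sentence))['text'][None]
        wav = sess.run(None, {'input_text': tokens})[0]
        logging.info("tokens=%d -> samples=%d", tokens.shape[1], len(wav))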
Sorry for the late response. I am checking for solutions and will update you once finished.
@Fhrozen Any updates?
Please check this: https://github.com/Masao-Someki/espnet_onnx
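
That project packages this whole export-and-inference flow. A sketch of its usage, based on the espnet_onnx README (TTSModelExport, Text2Speech, and the model tag are taken from that documentation and may differ between versions):

from espnet_onnx import Text2Speech
from espnet_onnx.export import TTSModelExport

# Export a pretrained ESPnet2 TTS model to ONNX
m = TTSModelExport()
m.export_from_pretrained('kan-bayashi/ljspeech_vits')

# Run inference with the exported model
text2speech = Text2Speech('kan-bayashi/ljspeech_vits')
wav = text2speech('Hello world')['wav']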
@Fhrozen I also encountered this error; how did you solve it?
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where_165' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:497 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 19
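
That failure is consistent with a shape being frozen at export time: torch.onnx.export traces the model with the dummy sentence (19 tokens here), and any mask or range whose size passes through a Python int is baked into the graph as a constant, so a later 8-token input cannot broadcast against it even though input_text is marked dynamic. A quick way to inspect which side of the failing node is a constant, assuming the onnx package and the node name taken from the error above:

import onnx

model = onnx.load('tts_model.onnx')
# Locate the node reported by onnxruntime and list its inputs;
# an initializer of length 19 points at the baked-in dummy shape.
node = next(n for n in model.graph.node if n.name == 'Where_165')
print(node.input, node.output)
init_names = {i.name for i in model.graph.initializer}
print([name for name in node.input if name in init_names])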
Hi @Fhrozen, did you get a chance to work on it?