-
-
Save Fhrozen/60b38bd3ee23492d28b602a0c9f92217 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3 | |
"""Convert TTS to ONNX | |
Using ESPnet. | |
Test command: | |
python convert_tts2onnx.py --tts-tag espnet/kan-bayashi_ljspeech_vits | |
""" | |
import argparse | |
import logging | |
import sys | |
import numpy as np | |
import torch | |
import time | |
from typing import Dict | |
from typing import Optional | |
from espnet2.bin.tts_inference import Text2Speech | |
from espnet2.utils.types import str_or_none | |
import torch.nn.functional as F | |
def get_parser():
    """Build the command-line argument parser for the conversion script.

    Returns:
        argparse.ArgumentParser: Parser exposing the single required
        ``--tts-tag`` option (available afterwards as ``args.tts_tag``).
    """
    parser = argparse.ArgumentParser(
        # A non-empty description so that ``--help`` says what the tool does.
        description="Convert an ESPnet TTS model to ONNX.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--tts-tag",
        required=True,
        type=str,
        help="TTS tag (or Directory) for model located at huggingface/zenodo/local",
    )
    return parser
# NOTE: this function is written as a method of the ESPnet VITS model
# (espnet2/gan_tts/vits/vits.py). Either add it to that class or bind it to a
# model instance before using it as the ONNX export entry point.
def inference_onnx(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    durations: Optional[torch.Tensor] = None,
    noise_scale: float = 0.667,
    noise_scale_dur: float = 0.8,
    alpha: float = 1.0,
    max_len: Optional[int] = None,
    use_teacher_forcing: bool = False,
) -> torch.Tensor:
    """Run inference for ONNX.

    Unlike a dict-returning inference call, this returns the waveform tensor
    directly so that it can serve as the traced ``forward`` of the model.

    Args:
        self: Object exposing ``generator.inference(...)``
            (presumably a VITS model instance -- confirm against caller).
        text: Input token tensor, forwarded unchanged to the generator.
        text_lengths: Length tensor, forwarded unchanged to the generator.
        sids: Optional speaker-id tensor; reshaped to shape ``(1,)``.
        spembs: Optional speaker embedding, forwarded unchanged.
        lids: Optional language-id tensor; reshaped to shape ``(1,)``.
        durations: Optional duration tensor; reshaped to ``(1, 1, -1)`` and
            passed to the generator as ``dur``.
        noise_scale: Forwarded to the generator.
        noise_scale_dur: Forwarded to the generator.
        alpha: Forwarded to the generator.
        max_len: Forwarded to the generator.
        use_teacher_forcing: Must be ``False``; teacher forcing is not
            implemented on this path.

    Returns:
        torch.Tensor: The generated waveform flattened to one dimension.

    Raises:
        NotImplementedError: If ``use_teacher_forcing`` is true.
    """
    # Normalize the optional conditioning tensors to the shapes passed on to
    # the generator.
    if sids is not None:
        sids = sids.view(1)
    if lids is not None:
        lids = lids.view(1)
    if durations is not None:
        durations = durations.view(1, 1, -1)

    if use_teacher_forcing:
        raise NotImplementedError(
            "Teacher forcing is not supported by inference_onnx."
        )

    # The generator returns a 3-tuple; only the waveform is needed here.
    wav, _, _ = self.generator.inference(
        text=text,
        text_lengths=text_lengths,
        sids=sids,
        spembs=spembs,
        lids=lids,
        dur=durations,
        noise_scale=noise_scale,
        noise_scale_dur=noise_scale_dur,
        alpha=alpha,
        max_len=max_len,
    )
    return wav.view(-1)
def test_onnx():
    """Smoke-test the exported ``tts_model.onnx`` with ONNX Runtime.

    Uses the module-level ``preprocessing`` callable (set up by the script's
    entry point) to tokenize a fixed sentence, runs it through the exported
    model and logs the model's input/output names and the result type.
    """
    logging.info('Test ONNX')
    # Imported here so that onnxruntime is only required when the check runs.
    import onnxruntime as ort

    sentence = 'Hello world, how are you doing'
    tokens = preprocessing("<dummy>", {"text": sentence})['text']
    # Prepend a leading axis before feeding the model.
    batch = tokens[None]

    session = ort.InferenceSession('tts_model.onnx')
    input_names = [node.name for node in session.get_inputs()]
    output_names = [node.name for node in session.get_outputs()]
    logging.info(
        "inputs name: %s || outputs name: %s", input_names, output_names
    )

    results = session.run(None, {'input_text': batch})
    logging.info(type(results))
if __name__ == "__main__":
    # Script entry point: load a pretrained ESPnet TTS model, synthesize a
    # test utterance, export the model to 'tts_model.onnx' and smoke-test it.

    # Only the entry point needs this, so it is imported locally.
    import types

    # Logger
    parser = get_parser()
    args = parser.parse_args()
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    # An explicit FileHandler is used because the ``encoding`` keyword of
    # ``logging.basicConfig`` only exists on Python >= 3.9.
    logging.basicConfig(
        level=logging.INFO,
        format=logfmt,
        handlers=[logging.FileHandler('onnx.log', encoding='utf-8')],
    )

    # Load Pretrained model and testing wav generation
    logging.info("Preparing pretrained model from: %s", args.tts_tag)
    text2speech = Text2Speech.from_pretrained(
        model_tag=str_or_none(args.tts_tag),
        vocoder_tag=None,
        # Fall back to CPU instead of failing on machines without CUDA.
        device="cuda" if torch.cuda.is_available() else "cpu",
        # Only for Tacotron 2 & Transformer
        threshold=0.5,
        # Only for Tacotron 2
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=False,
        backward_window=1,
        forward_window=3,
        # Only for FastSpeech & FastSpeech2 & VITS
        speed_control_alpha=1.0,
        # Only for VITS
        noise_scale=0.667,
        noise_scale_dur=0.8,
    )

    text = 'Hello world'
    logging.info("Generating test wav using the sequence: %s", text)
    with torch.no_grad():
        start = time.time()
        wav = text2speech(text)["wav"]
    # Real-time factor: processing time divided by generated audio duration.
    rtf = (time.time() - start) / (len(wav) / text2speech.fs)
    logging.info(f"RTF = {rtf:5f}")

    # Prepare modules for conversion
    logging.info("Generate ONNX models")
    with torch.no_grad():
        device = text2speech.device
        # Kept as a module-level name: test_onnx() reads it as a global.
        preprocessing = text2speech.preprocess_fn
        model_tts = text2speech.tts

        # Replace forward with inference to avoid problems at ONNX generation.
        # The model class does not provide ``inference_onnx`` unless it has
        # been patched, so bind the function defined in this file when the
        # attribute is missing (otherwise this line raises AttributeError).
        onnx_forward = getattr(model_tts, "inference_onnx", None)
        if onnx_forward is None:
            onnx_forward = types.MethodType(inference_onnx, model_tts)
        model_tts.forward = onnx_forward

        # Preprocessing data
        preproc_text = preprocessing("<dummy>", dict(text=text))['text']
        preproc_text = torch.from_numpy(preproc_text).to(device).unsqueeze(0)
        text_lengths = torch.tensor(
            [preproc_text.size(1)],
            dtype=torch.long,
            device=preproc_text.device,
        )

        # Sanity run through the replaced forward before exporting.
        wav = model_tts(preproc_text, text_lengths)
        logging.info(wav.shape)

        inputs = (preproc_text, text_lengths)

        # Generate TTS Model
        torch.onnx.export(
            model_tts,
            inputs,
            'tts_model.onnx',
            export_params=True,
            opset_version=13,
            do_constant_folding=True,
            verbose=True,
            input_names=['input_text'],
            output_names=['wav'],
            # Distinct symbolic names: reusing one name for both axes would
            # declare the text length and the waveform length to be equal.
            dynamic_axes={
                'input_text': {
                    1: 'text_length',
                },
                'wav': {
                    0: 'wav_length',
                },
            },
        )

    test_onnx()
    sys.exit(0)
L83-L97
@Fhrozen I am getting this error:
AttributeError Traceback (most recent call last)
/tmp/ipykernel_1834/548140431.py in <module>
5
6 # Replace forward with inference to avoid problems at ONNX generation
----> 7 model_tts.forward = model_tts.inference_onnx
8
9 # Preprocessing data
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
1176 return modules[name]
1177 raise AttributeError("'{}' object has no attribute '{}'".format(
-> 1178 type(self).__name__, name))
1179
1180 def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
AttributeError: 'VITS' object has no attribute 'inference_onnx'
Hi @Fhrozen did you get a chance to look into it?
@sciai-ai, sorry, I need a little more time. I expect to implement the fixes for this by next weekend at the latest.
Hi @Fhrozen did you get a chance to work on it?
Sorry for the late response. I am checking for solutions. I will update you once finished.
@Fhrozen Any updates?
Check this pls: https://github.com/Masao-Someki/espnet_onnx
@Fhrozen I also encountered this error, how did you solve it ?
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where_165' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:497 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 19
Where is the test_onnx method?