#!/usr/bin/env python3
"""Convert TTS to ONNX using ESPnet.

Test command:
    python convert_tts2onnx.py --tts-tag espnet/kan-bayashi_ljspeech_vits
"""
import argparse
import logging
import sys
import time
from typing import Optional

import numpy as np
import torch

from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none


def get_parser():
    parser = argparse.ArgumentParser(
        description="Convert an ESPnet TTS model to ONNX",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--tts-tag",
        required=True,
        type=str,
        help="TTS tag (or directory) for a model located at huggingface/zenodo/local",
    )
    return parser
### Add this at espnet2/gan_tts/vits/vits.py
def inference_onnx(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    sids: Optional[torch.Tensor] = None,
    spembs: Optional[torch.Tensor] = None,
    lids: Optional[torch.Tensor] = None,
    durations: Optional[torch.Tensor] = None,
    noise_scale: float = 0.667,
    noise_scale_dur: float = 0.8,
    alpha: float = 1.0,
    max_len: Optional[int] = None,
    use_teacher_forcing: bool = False,
) -> torch.Tensor:
    """Run inference for ONNX export; returns the waveform as a flat tensor."""
    if sids is not None:
        sids = sids.view(1)
    if lids is not None:
        lids = lids.view(1)
    if durations is not None:
        durations = durations.view(1, 1, -1)
    # inference
    if use_teacher_forcing:
        raise NotImplementedError
    wav, _, _ = self.generator.inference(
        text=text,
        text_lengths=text_lengths,
        sids=sids,
        spembs=spembs,
        lids=lids,
        dur=durations,
        noise_scale=noise_scale,
        noise_scale_dur=noise_scale_dur,
        alpha=alpha,
        max_len=max_len,
    )
    return wav.view(-1)
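If you prefer not to patch the ESPnet source, the same method can be bound onto the loaded model at runtime instead. A minimal sketch, assuming inference_onnx is defined at module level exactly as above and text2speech is the loaded Text2Speech instance from the main block below; this also avoids the AttributeError: 'VITS' object has no attribute 'inference_onnx' reported in the comments:

import types

# Hypothetical alternative to editing espnet2/gan_tts/vits/vits.py:
# bind the module-level inference_onnx function as a method of the
# loaded VITS instance, then use it as the forward for tracing.
model_tts = text2speech.tts
model_tts.inference_onnx = types.MethodType(inference_onnx, model_tts)
model_tts.forward = model_tts.inference_onnx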
def test_onnx():
    logging.info('Test ONNX')
    import onnxruntime as ort

    this_text = 'Hello world, how are you doing'
    this_text = preprocessing("<dummy>", dict(text=this_text))['text']
    this_text = this_text[None]  # add a batch dimension
    # this_len = np.array([this_text.shape[1]], dtype=int)
    ort_sess = ort.InferenceSession('tts_model.onnx')
    inname = [inp.name for inp in ort_sess.get_inputs()]
    outname = [out.name for out in ort_sess.get_outputs()]
    logging.info("inputs name: %s || outputs name: %s", inname, outname)
    outputs = ort_sess.run(None, {'input_text': this_text})
    logging.info(type(outputs))
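Note that only the first export input is named 'input_text'; if the text_lengths input survives tracing, ONNX assigns it an auto-generated name and the hard-coded feed above fails with a missing-input error. A hedged sketch that builds the feed dict from whatever inputs the session actually reports, reusing this_text from test_onnx (the assumption that any non-text input is the lengths tensor is mine, not from the original):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('tts_model.onnx')
feed = {}
for inp in sess.get_inputs():
    if inp.name == 'input_text':
        feed[inp.name] = this_text
    else:
        # Assumption: any remaining input is the lengths tensor.
        feed[inp.name] = np.array([this_text.shape[1]], dtype=np.int64)
outputs = sess.run(None, feed)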
if __name__ == "__main__":
    # Logger
    parser = get_parser()
    args = parser.parse_args()
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(
        filename='onnx.log', encoding='utf-8', level=logging.INFO, format=logfmt
    )

    # Load pretrained model and test wav generation
    logging.info("Preparing pretrained model from: %s", args.tts_tag)
    text2speech = Text2Speech.from_pretrained(
        model_tag=str_or_none(args.tts_tag),
        vocoder_tag=None,
        device="cuda",
        # Only for Tacotron 2 & Transformer
        threshold=0.5,
        # Only for Tacotron 2
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=False,
        backward_window=1,
        forward_window=3,
        # Only for FastSpeech & FastSpeech2 & VITS
        speed_control_alpha=1.0,
        # Only for VITS
        noise_scale=0.667,
        noise_scale_dur=0.8,
    )
    text = 'Hello world'
    logging.info("Generating test wav using the sequence: %s", text)
    with torch.no_grad():
        start = time.time()
        wav = text2speech(text)["wav"]
    rtf = (time.time() - start) / (len(wav) / text2speech.fs)
    logging.info(f"RTF = {rtf:.5f}")

    # Prepare modules for conversion
    logging.info("Generate ONNX models")
    with torch.no_grad():
        device = text2speech.device
        preprocessing = text2speech.preprocess_fn
        model_tts = text2speech.tts
        # Replace forward with inference to avoid problems at ONNX generation
        model_tts.forward = model_tts.inference_onnx
        # Preprocess the input text
        preproc_text = preprocessing("<dummy>", dict(text=text))['text']
        preproc_text = torch.from_numpy(preproc_text).to(device).unsqueeze(0)
        text_lengths = torch.tensor(
            [preproc_text.size(1)],
            dtype=torch.long,
            device=preproc_text.device,
        )
        wav = model_tts(preproc_text, text_lengths)
        logging.info(wav.shape)
        inputs = (preproc_text, text_lengths)
        # Generate the TTS ONNX model
        torch.onnx.export(
            model_tts,
            inputs,
            'tts_model.onnx',
            export_params=True,
            opset_version=13,
            do_constant_folding=True,
            verbose=True,
            input_names=['input_text'],
            output_names=['wav'],
            dynamic_axes={
                'input_text': {1: 'length'},
                'wav': {0: 'length'},
            },
        )
    test_onnx()
    sys.exit(0)
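Once the export succeeds, a quick sanity check is to compare the PyTorch and ONNX Runtime outputs on the same input. A minimal sketch, assuming model_tts, preproc_text, and text_lengths from the script above are still in scope; VITS samples noise internally, so the two waveforms will not match exactly and only shapes and rough magnitudes are compared:

import numpy as np
import onnxruntime as ort

with torch.no_grad():
    ref = model_tts(preproc_text, text_lengths).cpu().numpy()

sess = ort.InferenceSession('tts_model.onnx')
onnx_wav = sess.run(None, {'input_text': preproc_text.cpu().numpy()})[0]

# Stochastic sampling means exact equality is not expected.
print(ref.shape, onnx_wav.shape)
print('max |torch|:', np.abs(ref).max(), 'max |onnx|:', np.abs(onnx_wav).max())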
@Fhrozen I am getting the same error with opset_version=11/12/13
RuntimeError: Exporting the operator _thnn_fused_lstm_cell to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
Could you tell me which pretrained model you are using?
I am using Tacotron2, which I believe uses an LSTM.
There is exactly the same situation for Tacotron2 as yours in this issue thread: pytorch/pytorch#25533
You will need to check the comments, and especially these links:
https://colab.research.google.com/drive/1n6t-mtM8DBBiP8er6jpU_7KxwYBeNJYf?usp=sharing
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tensorrt/convert_tacotron22onnx.py
@Fhrozen Thanks for sharing these; I googled it too and found the same threads. It does not seem very straightforward to do this :(
If your ONNX conversion works with the joint model, then I might retrain my data on it.
This code is for the joint model (Transformer/FastSpeech + ParallelWaveGAN), but it still has the Python list/dict problems that I will check later. It is probably easier for VITS (I did not try it, so I cannot confirm).
@Fhrozen I was wondering if you managed to get it working. Thanks
@sciai-ai, I think it is about 50% of the way there. Many parts need to change, because some variables are converted during the torch -> ONNX trace, and those conversions bake in constants that later cause errors.
If you can, test the updated code and see whether you are able to fix any additional issues.
The current warnings I am having are listed below (see the sketch of trace-safe rewrites after the list):
/export/db/espnet/converter/espnet/nets/pytorch_backend/nets_utils.py:154: TracerWarning: Converting a tensor to a Python list might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
lengths = lengths.tolist()
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/embedding.py:198: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if self.pe.size(1) >= x.size(1) * 2 - 1:
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/embedding.py:239: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/attention.py:256: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
:, :, :, : x.size(-1) // 2 + 1
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/attention.py:81: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
/export/db/espnet/converter/espnet/nets/pytorch_backend/transformer/attention.py:81: TracerWarning: Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
/export/db/espnet/converter/espnet2/gan_tts/vits/text_encoder.py:139: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
m, logs = stats.split(stats.size(1) // 2, dim=1)
/export/db/espnet/converter/espnet2/gan_tts/vits/flow.py:285: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
xa, xb = x.split(x.size(1) // 2, 1)
/export/db/espnet/converter/espnet2/gan_tts/vits/transform.py:118: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if torch.min(inputs) < left or torch.max(inputs) > right:
/export/db/espnet/converter/espnet2/gan_tts/vits/transform.py:123: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if min_bin_width * num_bins > 1.0:
/export/db/espnet/converter/espnet2/gan_tts/vits/transform.py:125: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if min_bin_height * num_bins > 1.0:
/export/db/espnet/converter/espnet2/gan_tts/vits/transform.py:175: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert (discriminant >= 0).all()
/export/db/espnet/converter/espnet2/gan_tts/vits/residual_coupling.py:211: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
xa, xb = x.split(x.size(1) // 2, dim=1)
/export/db/espnet/converter/espnet2/gan_tts/wavenet/residual_block.py:140: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
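Most of these warnings flag Python-level values (lists, booleans, // on traced sizes) that the tracer freezes into constants. A hedged sketch of the usual trace-safe rewrites for the two most common patterns above; this is illustrative, not the actual ESPnet fix:

import torch

# 1) Instead of lengths.tolist(), build the pad mask with tensor ops so
#    the trace records the data flow and generalizes to other lengths.
def make_pad_mask(lengths: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    # True at padded positions; xs is (batch, time, ...).
    seq = torch.arange(xs.size(1), device=lengths.device)
    return seq.unsqueeze(0) >= lengths.unsqueeze(1)

# 2) Instead of the deprecated // on traced sizes, use torch.div with an
#    explicit rounding mode, as the warning itself suggests:
#     m, logs = stats.split(stats.size(1) // 2, dim=1)     # warns
#    becomes
#     half = torch.div(stats.size(1), 2, rounding_mode='floor')
#     m, logs = stats.split(half, dim=1)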
When you run the test_onnx method, you will have an error:
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where_165' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:497 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 19
I will try to check on the weekend.
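One way to localize the failing node is to load the exported graph and inspect it with the onnx Python API. A minimal sketch; the node name Where_165 is taken from the error message above:

import onnx

model = onnx.load('tts_model.onnx')
inferred = onnx.shape_inference.infer_shapes(model)

# Find the node named in the runtime error and list its input/output
# tensors, to see which two operands fail to broadcast (8 vs 19).
for node in inferred.graph.node:
    if node.name == 'Where_165':
        print(node.op_type, node.input, node.output)

# Inferred intermediate shapes, keyed by tensor name, for cross-checking.
shapes = {vi.name: vi.type for vi in inferred.graph.value_info}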
Where is the test_onnx method?
L83-L97
@Fhrozen I am getting this error:
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_1834/548140431.py in <module>
      5
      6 # Replace forward with inference to avoid problems at ONNX generation
----> 7 model_tts.forward = model_tts.inference_onnx
      8
      9 # Preprocessing data
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1176             return modules[name]
   1177         raise AttributeError("'{}' object has no attribute '{}'".format(
-> 1178             type(self).__name__, name))
   1179
   1180     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
AttributeError: 'VITS' object has no attribute 'inference_onnx'
Hi @Fhrozen, did you get a chance to look into it?
@sciai-ai, sorry, I need a little more time. I expect to implement the fixes this weekend or next weekend at the latest.
Hi @Fhrozen, did you get a chance to work on it?
Sorry for the late response. I am checking for solutions. I will update you once finished.
@Fhrozen Any updates?
Please check this: https://github.com/Masao-Someki/espnet_onnx
@Fhrozen I also encountered this error; how did you solve it?
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'Where_165' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:497 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 19
A command for test:
python convert_tts2onnx.py --tts-tag espnet/kan-bayashi_ljspeech_vits
The model is a single-speaker VITS model. Not tested on other models.
Installation:
Issues:
Some internal processes still need checking: the script can generate the model, but the generated model cannot be used yet.
I suppose the problem is that some variables are converted to Python lists/dicts during tracing.