Created
May 3, 2022 07:30
-
-
Save taroushirani/52125daf86a745ee59354f700eed77d0 to your computer and use it in GitHub Desktop.
CLI for NNSVS packed model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/python | |
import argparse | |
import logging | |
from nnmnkwii.io import hts | |
from nnsvs.dsp import bandpass_filter | |
from nnsvs.gen import ( | |
gen_spsvs_static_features, | |
gen_world_params, | |
postprocess_duration, | |
predict_acoustic, | |
predict_duration, | |
predict_timelag, | |
) | |
from nnsvs.svs import SPSVS | |
import numpy as np | |
import os | |
from os.path import expanduser, exists, join | |
import pyworld | |
import sys | |
from scipy.io import wavfile | |
import torch | |
from tqdm import tqdm | |
class SPSVSMod(SPSVS):
    """Modified version of SPSVS, Statistical Parametric Singing Voice Synthesis.

    Exposes the synthesis pipeline as separately callable stages (label
    modification, static-feature generation, waveform synthesis) and adds
    label segmentation plus V/UV and silence post-fixes.

    Args:
        model_dir (str): directory of the model
        device (str): cpu or cuda
    """

    def __init__(self, model_dir, device="cpu"):
        super().__init__(model_dir, device)

    @torch.no_grad()
    def get_modified_labels(self, labels):
        """Get timelag-and-duration modified labels.

        Runs the time-lag and duration models on ``labels`` and returns the
        result of ``postprocess_duration`` applied to both predictions.

        Args:
            labels (nnmnkwii.io.HTSLabelFile): HTS full-context labels

        Returns:
            nnmnkwii.io.HTSLabelFile: HTS full-context labels
        """
        # Time-lag
        lag = predict_timelag(
            self.device,
            labels,
            self.timelag_model,
            self.timelag_config,
            self.timelag_in_scaler,
            self.timelag_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.timelag.allowed_range,
            self.config.timelag.allowed_range_rest,
            self.config.timelag.force_clip_input_features,
        )

        # Duration predictions
        durations = predict_duration(
            self.device,
            labels,
            self.duration_model,
            self.duration_config,
            self.duration_in_scaler,
            self.duration_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.duration.force_clip_input_features,
        )

        # Normalize phoneme durations
        return postprocess_duration(labels, durations, lag)

    @torch.no_grad()
    def get_static_features(
        self,
        labels,
        post_filter=True,
        vuv_threshold=0.1,
        vibrato_scale=1.0,
        ground_truth=False,
    ):
        """Get static features from given HTS full-context labels.

        Args:
            labels (nnmnkwii.io.HTSLabelFile): HTS full-context labels
            post_filter (bool): forwarded to ``gen_spsvs_static_features``
            vuv_threshold (float): forwarded to ``gen_spsvs_static_features``
            vibrato_scale (float): forwarded to ``gen_spsvs_static_features``
            ground_truth (bool): if True, ``labels`` are used as they are;
                otherwise they are first passed through
                :meth:`get_modified_labels`.

        Returns:
            Whatever ``gen_spsvs_static_features`` returns; the script in this
            file unpacks it as ``(mgc, lf0, vuv, bap)``.
        """
        if ground_truth:
            modified_labels = labels
        else:
            modified_labels = self.get_modified_labels(labels)

        acoustic_features = predict_acoustic(
            self.device,
            modified_labels,
            self.acoustic_model,
            self.acoustic_config,
            self.acoustic_in_scaler,
            self.acoustic_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.config.acoustic.subphone_features,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.acoustic.force_clip_input_features,
        )

        # Generate WORLD parameters
        return gen_spsvs_static_features(
            modified_labels,
            acoustic_features,
            self.binary_dict,
            self.numeric_dict,
            self.acoustic_config.stream_sizes,
            self.acoustic_config.has_dynamic_features,
            self.config.acoustic.subphone_features,
            self.pitch_idx,
            self.acoustic_config.num_windows,
            post_filter,
            self.config.sample_rate,
            self.config.frame_period,
            self.config.acoustic.relative_f0,
            vibrato_scale=vibrato_scale,
            vuv_threshold=vuv_threshold,
        )

    @torch.no_grad()
    def synthesize(self, mgc, lf0, vuv, bap, vocoder_type, vuv_threshold):
        """Get waveform from given acoustic features.

        Args:
            mgc: static feature stream, as produced by
                :meth:`get_static_features`
            lf0: static feature stream, as produced by
                :meth:`get_static_features`
            vuv: static feature stream, as produced by
                :meth:`get_static_features`
            bap: static feature stream, as produced by
                :meth:`get_static_features`
            vocoder_type (str): ``"world"`` or ``"pwg"``
            vuv_threshold (float): threshold used to binarize ``vuv`` for the
                ``"pwg"`` vocoder; unused for ``"world"``.

        Returns:
            tuple: ``(wav, sample_rate)`` where ``wav`` is the int16 array
            returned by :meth:`post_process` and ``sample_rate`` is
            ``self.config.sample_rate``.

        Raises:
            ValueError: if ``vocoder_type`` is neither ``"world"`` nor
                ``"pwg"``.
        """
        # Waveform generation by (1) WORLD or (2) neural vocoder
        if vocoder_type == "world":
            f0, spectrogram, aperiodicity = gen_world_params(
                mgc, lf0, vuv, bap, self.config.sample_rate
            )
            wav = pyworld.synthesize(
                f0,
                spectrogram,
                aperiodicity,
                self.config.sample_rate,
                self.config.frame_period,
            )
        elif vocoder_type == "pwg":
            # NOTE: So far vocoder models are trained on binary V/UV features
            vuv = (vuv > vuv_threshold).astype(np.float32)
            voc_inp = (
                torch.from_numpy(
                    self.vocoder_in_scaler.transform(
                        np.concatenate([mgc, lf0, vuv, bap], axis=-1)
                    )
                )
                .float()
                .to(self.device)
            )
            wav = self.vocoder.inference(voc_inp).view(-1).to("cpu").numpy()
        else:
            # Without this branch ``wav`` would be unbound below and the
            # caller would only see an opaque UnboundLocalError.
            raise ValueError(f"Unknown vocoder type: {vocoder_type}")

        wav = self.post_process(wav)

        return wav, self.config.sample_rate

    def _compute_nosil_duration(self, lab, threshold=5.0):
        """Sum label durations in seconds, skipping long silences.

        Args:
            lab: iterable of ``(start, end, label)`` entries with times in
                100 ns units (they are scaled by ``1e-7`` to seconds).
            threshold (float): sil/pau entries longer than this many seconds
                are left out of the sum.

        Returns:
            float: accumulated duration in seconds (``0`` for empty input).
        """
        if len(lab) == 0:
            return 0

        # The label format is decided once, from the first entry.
        is_full_context = "@" in lab[0][-1]
        sum_d = 0
        for s, e, label in lab:
            d = (e - s) * 1e-7
            if is_full_context:
                is_silence = "-sil" in label or "-pau" in label
            else:
                is_silence = label == "sil" or label == "pau"

            if not (is_silence and d > threshold):
                sum_d += d
        return sum_d

    def _is_silence(self, label):
        """Return True if ``label`` denotes sil/pau.

        Labels containing ``"@"`` are treated as full-context labels and are
        matched on the ``-sil`` / ``-pau`` substrings; any other label must be
        exactly ``"sil"`` or ``"pau"``.
        """
        if "@" in label:
            return "-sil" in label or "-pau" in label
        return label == "sil" or label == "pau"

    def segment_labels(
        self,
        lab,
        strict=True,
        silence_threshold=1.0,
        min_duration=10.0,
        force_split_threshold=30.0,
    ):
        """Segment labels based on sil/pau.

        A split happens at a label that is either longer than
        ``force_split_threshold``, or is a silence longer than
        ``silence_threshold`` while the segment collected so far has more than
        ``min_duration`` seconds of non-silent content.  The label that
        triggers a split is kept in both neighbours: it ends the current
        segment and starts the next one.

        Example (assuming every sil/pau qualifies for a split):
            [a b c sil d e f pau g h i]
            ->
            [a b c sil] [sil d e f pau] [pau g h i]

        Args:
            lab (nnmnkwii.io.HTSLabelFile): labels to segment
            strict (bool): forwarded to ``HTSLabelFile.append``
            silence_threshold (float): minimum silence length (sec) that may
                trigger a split
            min_duration (float): minimum non-silent duration (sec) of the
                current segment before a silence may split it
            force_split_threshold (float): any label longer than this (sec)
                forces a split regardless of ``min_duration``

        Returns:
            list: slices of ``lab``, each with start/end times shifted so the
            segment begins at time 0.
        """
        seg = hts.HTSLabelFile()
        start_indices = []
        end_indices = []
        si = 0
        large_silence_detected = False

        for idx, (s, e, label) in enumerate(lab):
            d = (e - s) * 1e-7
            is_silence = self._is_silence(label)

            if len(seg) > 0:
                # Compute duration except for long silences
                seg_d = self._compute_nosil_duration(seg)
            else:
                seg_d = 0

            # let's try to split
            # if we find large silence, force split regardless min_duration
            if (d > force_split_threshold) or (
                is_silence and d > silence_threshold and seg_d > min_duration
            ):
                if idx == len(lab) - 1:
                    # A split point at the very end is simply dropped.
                    continue
                elif len(seg) > 0:
                    large_silence_detected = d > force_split_threshold
                    # Appended only so that ``strict`` validation still runs
                    # on the closing label; the segment itself is re-sliced
                    # from ``lab`` below.
                    seg.append((s, e, label), strict)
                    start_indices.append(si)
                    end_indices.append(idx)
                    # The next segment starts with the same label.
                    seg = hts.HTSLabelFile()
                    si = idx
                    seg.append((s, e, label), strict)
                continue
            else:
                if len(seg) == 0:
                    si = idx
                seg.append((s, e, label), strict)

        if len(seg) > 0:
            seg_d = self._compute_nosil_duration(seg)
            # If the last segment is short, combine with the previous segment.
            # When there is no previous segment (no split happened at all) the
            # remainder must become a segment of its own; indexing
            # ``end_indices[-1]`` would raise IndexError.
            if end_indices and seg_d < min_duration and not large_silence_detected:
                end_indices[-1] = si + len(seg) - 1
            else:
                start_indices.append(si)
                end_indices.append(si + len(seg) - 1)

        # Re-slice from the original labels and make times segment-relative.
        segments2 = []
        for s, e in zip(start_indices, end_indices):
            seg = lab[s : e + 1]
            offset = seg.start_times[0]
            seg.start_times = np.asarray(seg.start_times) - offset
            seg.end_times = np.asarray(seg.end_times) - offset
            segments2.append(seg)

        return segments2

    @staticmethod
    def _require_unique_start_times(lab):
        """Raise RuntimeError if two labels in ``lab`` share a start time."""
        if len(lab.start_times) != len(set(lab.start_times)):
            raise RuntimeError("Label file is not ground-truth-ed")

    def fix_vuv(self, lab, vuv):
        """Zero ``vuv`` over every label containing ``-br``.

        ``vuv`` is modified in place and also returned.

        Args:
            lab (nnmnkwii.io.HTSLabelFile): labels with times in 100 ns units
            vuv (numpy.ndarray): V/UV track indexed by frame on axis 0

        Returns:
            numpy.ndarray: the same ``vuv`` array.

        Raises:
            RuntimeError: if two labels share a start time.
        """
        self._require_unique_start_times(lab)

        print("Fix V/UV")
        # frame_period * 1e4 is the divisor the label times need to become
        # frame indices.
        frame_len = self.config.frame_period * 1e4
        for s, e, label in tqdm(lab):
            if "-br" in label:
                si = int(np.round(s / frame_len))
                # The slice end is exclusive, so clamp to the length (not
                # length - 1); otherwise the final frame is never cleared.
                ei = min(int(np.round(e / frame_len)), vuv.shape[0])
                vuv[si:ei] = 0

        return vuv

    def suppress_silent(self, lab, wav, sr, suppression_ratio):
        """Attenuate ``wav`` over every sil/pau label.

        Samples inside silent labels are scaled by ``1 - suppression_ratio``
        and rounded to int16.  ``wav`` is modified in place and also returned.

        Args:
            lab (nnmnkwii.io.HTSLabelFile): labels with times in 100 ns units
            wav (numpy.ndarray): waveform indexed by sample on axis 0
            sr (int): sampling rate of ``wav``
            suppression_ratio (float): fraction of amplitude to remove

        Returns:
            numpy.ndarray: the same ``wav`` array.

        Raises:
            RuntimeError: if two labels share a start time.
        """
        self._require_unique_start_times(lab)

        print("Fix silent")
        gain = 1.0 - suppression_ratio
        for s, e, label in lab:
            if self._is_silence(label):
                si = int(np.round(s * 1e-7 * sr))
                # Exclusive slice end: clamp to the length so the last sample
                # is attenuated as well.
                ei = min(int(np.round(e * 1e-7 * sr)), wav.shape[0])
                wav[si:ei] = np.round(wav[si:ei] * gain).astype(np.int16)

        return wav

    def post_process(self, wav):
        """Band-pass filter ``wav`` and peak-normalize it to int16.

        Args:
            wav (numpy.ndarray): waveform to process

        Returns:
            numpy.ndarray: int16 waveform whose absolute peak is
            ``2 ** 15 - 1`` (all zeros if the filtered signal is silent).
        """
        wav = bandpass_filter(wav, self.config.sample_rate)
        peak = np.max(np.abs(wav))
        if peak == 0:
            # Nothing to normalize; dividing would yield NaN samples.
            return np.zeros_like(wav, dtype=np.int16)
        return (wav / peak * (2 ** 15 - 1)).astype(np.int16)
def get_parser():
    """Build the command-line parser for this script.

    Positional arguments are the packed model directory, the input directory
    and the output directory.  The remaining options select the utterance
    list, vocoder and device, tune the thresholds used for synthesis and
    label splitting, and toggle the optional processing steps.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    cli = argparse.ArgumentParser(description='Sample script for SPSVS class')

    def add_flag(option, text):
        # Boolean switch that is False unless given on the command line.
        cli.add_argument(option, action='store_true', help=text)

    def add_float(option, fallback, text):
        # Float-valued option; argparse converts the string default via type.
        cli.add_argument(option, type=float, default=fallback, help=text)

    # Required paths.
    for dest, placeholder, text in (
        ('model_dir', 'MODEL_DIR', 'packed model dir'),
        ('in_dir', 'IN_DIR', 'input directory'),
        ('out_dir', 'OUT_DIR', 'output directory'),
    ):
        cli.add_argument(dest, metavar=placeholder, type=str, help=text)

    cli.add_argument(
        '--utt_list',
        metavar='FILE',
        type=str,
        default='song_list.txt',
        help='list of utterance (default: song_list.txt)',
    )
    cli.add_argument(
        '--vocoder_type',
        metavar='TYPE',
        type=str,
        default='world',
        help='vocoder_type (default: world)',
    )
    add_flag('--debug', 'Debug Mode')
    cli.add_argument(
        '--device',
        metavar='DEVICE',
        type=str,
        default='cuda',
        help='device (default: cuda)',
    )
    add_flag('--disable_post_filter', 'Disable post filter')

    # Numeric tuning knobs.
    for option, fallback, text in (
        ('--vuv_threshold', '0.3', 'V/UV threshold(default: 0.3)'),
        ('--silence_threshold', '1.0', 'Silence threshold for split(default 1.0)'),
        ('--min_duration', '30.0', 'Minimum duration for split(default 30.0)'),
        ('--force_split_threshold', '60.0', 'Force split threshold(default 60.0)'),
        ('--vibrato_scale', '1.0', 'Vibrato scale(default: 1.0)'),
    ):
        add_float(option, fallback, text)

    # Optional processing steps.
    for option, text in (
        ('--split', 'Split synthesis'),
        ('--return_states', 'Return states'),
        ('--fix_vuv', 'Fix VUV'),
        ('--force_whispering', 'Force whispering'),
        ('--suppress_silent', 'Suppress silent'),
    ):
        add_flag(option, text)

    add_float('--suppression_ratio', '0.8', 'Suppression ratio(default: 0.8)')

    return cli
if __name__ == "__main__":
    # Script entry point: synthesize one wav (or one wav per segment when
    # --split is given) for every utterance listed in the utterance list.
    args = get_parser().parse_args(sys.argv[1:])
    use_post_filter = not args.disable_post_filter

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    logging.debug(args)

    spsvs = SPSVSMod(args.model_dir, args.device)
    logging.debug(spsvs)

    in_dir = expanduser(args.in_dir)
    out_dir = expanduser(args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # One utterance id per non-blank line of the list file.
    with open(expanduser(join(in_dir, args.utt_list))) as list_file:
        utt_ids = [row.strip() for row in list_file.readlines() if len(row.strip()) > 0]
    print(f"Processes {len(utt_ids)} utterances...")

    def _render(timed_labels, wav_stem):
        # Labels -> static features -> optional fixes -> waveform -> wav file.
        # The labels are already timed, hence ground_truth=True.
        mgc, lf0, vuv, bap = spsvs.get_static_features(
            timed_labels,
            post_filter=use_post_filter,
            vuv_threshold=args.vuv_threshold,
            vibrato_scale=args.vibrato_scale,
            ground_truth=True,
        )
        if args.fix_vuv:
            vuv = spsvs.fix_vuv(timed_labels, vuv)
        if args.force_whispering:
            # Clear every V/UV value.
            vuv[:, None] = 0
        wav, sr = spsvs.synthesize(
            mgc, lf0, vuv, bap, args.vocoder_type, args.vuv_threshold
        )
        if args.suppress_silent:
            wav = spsvs.suppress_silent(timed_labels, wav, sr, args.suppression_ratio)
        wavfile.write(
            join(out_dir, f"{wav_stem}.wav"), rate=sr, data=wav.astype(np.int16)
        )

    for utt_id in tqdm(utt_ids):
        label_path = join(in_dir, f"{utt_id}.lab")
        if not exists(label_path):
            raise RuntimeError(f"Label file does not exist: {label_path}")

        # load labels and question
        labels = hts.load(label_path).round_()
        modified_labels = spsvs.get_modified_labels(labels)

        if not args.split:
            _render(modified_labels, utt_id)
            continue

        segments = spsvs.segment_labels(
            modified_labels,
            silence_threshold=args.silence_threshold,
            min_duration=args.min_duration,
            force_split_threshold=args.force_split_threshold,
        )
        print(f"Split into {len(segments)} segments...")
        for seg_idx, seg in tqdm(enumerate(segments)):
            # Write the segment's labels next to its wav.
            with open(join(out_dir, f"{utt_id}_seg{seg_idx}.lab"), "w") as lab_out:
                lab_out.write(str(seg))
            _render(seg, f"{utt_id}_seg{seg_idx}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment