@taroushirani · Created May 3, 2022
CLI for NNSVS packed model
#!/usr/bin/env python
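"""CLI for NNSVS packed models.

A modified SPSVS wrapper plus a command-line driver. Example invocation
(a sketch; the script name and paths are illustrative, and the flags are
defined in get_parser() below):

    python nnsvs_cli.py ~/models/packed_model ~/songs ~/out \
        --device cpu --vocoder_type world --split
"""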
import argparse
import logging
import os
import sys
from os.path import exists, expanduser, join

import numpy as np
import pyworld
import torch
from nnmnkwii.io import hts
from nnsvs.dsp import bandpass_filter
from nnsvs.gen import (
    gen_spsvs_static_features,
    gen_world_params,
    postprocess_duration,
    predict_acoustic,
    predict_duration,
    predict_timelag,
)
from nnsvs.svs import SPSVS
from scipy.io import wavfile
from tqdm import tqdm


class SPSVSMod(SPSVS):
    """Modified version of SPSVS (Statistical Parametric Singing Voice Synthesis).

    Args:
        model_dir (str): directory of the packed model
        device (str): "cpu" or "cuda"
    """

    def __init__(self, model_dir, device="cpu"):
        super(SPSVSMod, self).__init__(model_dir, device)

    @torch.no_grad()
    def get_modified_labels(self, labels):
        """Get time-lag- and duration-modified labels.

        Args:
            labels (nnmnkwii.io.HTSLabelFile): HTS full-context labels

        Returns:
            nnmnkwii.io.HTSLabelFile: HTS full-context labels with predicted
            time-lags and phoneme durations applied
        """
        # Predict note-level time-lags
        lag = predict_timelag(
            self.device,
            labels,
            self.timelag_model,
            self.timelag_config,
            self.timelag_in_scaler,
            self.timelag_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.timelag.allowed_range,
            self.config.timelag.allowed_range_rest,
            self.config.timelag.force_clip_input_features,
        )
        # Predict phoneme durations
        durations = predict_duration(
            self.device,
            labels,
            self.duration_model,
            self.duration_config,
            self.duration_in_scaler,
            self.duration_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.duration.force_clip_input_features,
        )
        # Normalize phoneme durations and apply the predicted time-lags
        return postprocess_duration(labels, durations, lag)

    @torch.no_grad()
    def get_static_features(
        self,
        labels,
        post_filter=True,
        vuv_threshold=0.1,
        vibrato_scale=1.0,
        ground_truth=False,
    ):
        """Get static acoustic features from the given HTS full-context labels.

        Args:
            labels (nnmnkwii.io.HTSLabelFile): HTS full-context labels
            post_filter (bool): apply post-filtering
            vuv_threshold (float): V/UV threshold
            vibrato_scale (float): vibrato scale
            ground_truth (bool): if True, use the labels as-is instead of
                predicting time-lags and durations

        Returns:
            tuple of numpy.ndarray: (mgc, lf0, vuv, bap)
        """
        if ground_truth:
            modified_labels = labels
        else:
            modified_labels = self.get_modified_labels(labels)

        acoustic_features = predict_acoustic(
            self.device,
            modified_labels,
            self.acoustic_model,
            self.acoustic_config,
            self.acoustic_in_scaler,
            self.acoustic_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.config.acoustic.subphone_features,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.acoustic.force_clip_input_features,
        )
        # Split the acoustic feature matrix into static streams (mgc, lf0, vuv, bap)
        return gen_spsvs_static_features(
            modified_labels,
            acoustic_features,
            self.binary_dict,
            self.numeric_dict,
            self.acoustic_config.stream_sizes,
            self.acoustic_config.has_dynamic_features,
            self.config.acoustic.subphone_features,
            self.pitch_idx,
            self.acoustic_config.num_windows,
            post_filter,
            self.config.sample_rate,
            self.config.frame_period,
            self.config.acoustic.relative_f0,
            vibrato_scale=vibrato_scale,
            vuv_threshold=vuv_threshold,
        )

    @torch.no_grad()
    def synthesize(self, mgc, lf0, vuv, bap, vocoder_type, vuv_threshold):
        """Generate a waveform from the given acoustic features.

        Args:
            mgc (numpy.ndarray): mel-generalized cepstrum
            lf0 (numpy.ndarray): log F0
            vuv (numpy.ndarray): voiced/unvoiced flags
            bap (numpy.ndarray): band aperiodicity
            vocoder_type (str): "world" or "pwg"
            vuv_threshold (float): V/UV threshold

        Returns:
            tuple: (waveform as numpy.ndarray, sample rate)
        """
        # Waveform generation by (1) WORLD or (2) a neural vocoder
        if vocoder_type == "world":
            f0, spectrogram, aperiodicity = gen_world_params(
                mgc, lf0, vuv, bap, self.config.sample_rate
            )
            wav = pyworld.synthesize(
                f0,
                spectrogram,
                aperiodicity,
                self.config.sample_rate,
                self.config.frame_period,
            )
        elif vocoder_type == "pwg":
            # NOTE: so far, vocoder models are trained on binary V/UV features
            vuv = (vuv > vuv_threshold).astype(np.float32)
            voc_inp = (
                torch.from_numpy(
                    self.vocoder_in_scaler.transform(
                        np.concatenate([mgc, lf0, vuv, bap], axis=-1)
                    )
                )
                .float()
                .to(self.device)
            )
            wav = self.vocoder.inference(voc_inp).view(-1).to("cpu").numpy()
        else:
            raise ValueError(f"Unknown vocoder type: {vocoder_type}")

        wav = self.post_process(wav)
        return wav, self.config.sample_rate

    def _compute_nosil_duration(self, lab, threshold=5.0):
        """Total duration of a label in seconds, excluding long silences."""
        is_full_context = "@" in lab[0][-1]
        sum_d = 0
        for s, e, label in lab:
            # HTS label times are in units of 100 ns
            d = (e - s) * 1e-7
            if is_full_context:
                is_silence = "-sil" in label or "-pau" in label
            else:
                is_silence = label == "sil" or label == "pau"
            # Skip silences longer than the threshold
            if not (is_silence and d > threshold):
                sum_d += d
        return sum_d

    def _is_silence(self, label):
        is_full_context = "@" in label
        if is_full_context:
            return "-sil" in label or "-pau" in label
        return label == "sil" or label == "pau"

    def segment_labels(
        self,
        lab,
        strict=True,
        silence_threshold=1.0,
        min_duration=10.0,
        force_split_threshold=30.0,
    ):
        """Segment labels on sil/pau boundaries.

        Example:
            [a b c sil d e f pau g h i sil j k l]
            ->
            [a b c] [d e f] [g h i] [j k l]
        """
        segments = []
        seg = hts.HTSLabelFile()
        start_indices = []
        end_indices = []
        si = 0
        large_silence_detected = False

        for idx, (s, e, label) in enumerate(lab):
            d = (e - s) * 1e-7
            is_silence = self._is_silence(label)
            if len(seg) > 0:
                # Compute duration, excluding long silences
                seg_d = self._compute_nosil_duration(seg)
            else:
                seg_d = 0
            # Try to split; on a large silence, force a split regardless of
            # min_duration
            if (d > force_split_threshold) or (
                is_silence and d > silence_threshold and seg_d > min_duration
            ):
                if idx == len(lab) - 1:
                    continue
                elif len(seg) > 0:
                    large_silence_detected = d > force_split_threshold
                    # Close the current segment with this silence ...
                    seg.append((s, e, label), strict)
                    start_indices.append(si)
                    end_indices.append(idx)
                    segments.append(seg)
                    seg = hts.HTSLabelFile()
                    si = idx
                # ... and start the next segment with the same silence
                seg.append((s, e, label), strict)
                continue
            else:
                if len(seg) == 0:
                    si = idx
                seg.append((s, e, label), strict)

        if len(seg) > 0:
            seg_d = self._compute_nosil_duration(seg)
            # If the last segment is short, merge it into the previous one
            if seg_d < min_duration and not large_silence_detected:
                end_indices[-1] = si + len(seg) - 1
            else:
                start_indices.append(si)
                end_indices.append(si + len(seg) - 1)

        # Re-slice the original labels and zero each segment's time offset
        segments2 = []
        for s, e in zip(start_indices, end_indices):
            seg = lab[s : e + 1]
            offset = seg.start_times[0]
            seg.start_times = np.asarray(seg.start_times) - offset
            seg.end_times = np.asarray(seg.end_times) - offset
            segments2.append(seg)
        return segments2

    def fix_vuv(self, lab, vuv):
        """Force breath ("br") segments to be unvoiced."""
        if len(lab.start_times) != len(set(lab.start_times)):
            raise RuntimeError("Label file is not ground-truth-ed")
        print("Fix V/UV")
        for s, e, label in tqdm(lab):
            if "-br" in label:
                # Convert label times (units of 100 ns) to frame indices
                si = np.round(s / (self.config.frame_period * 1e4)).astype(np.uint32)
                ei = np.round(e / (self.config.frame_period * 1e4)).astype(np.uint32)
                ei = min(ei, vuv.shape[0] - 1)
                vuv[si:ei] = 0
        return vuv

    def suppress_silent(self, lab, wav, sr, suppression_ratio):
        """Attenuate sil/pau regions of the waveform by suppression_ratio."""
        if len(lab.start_times) != len(set(lab.start_times)):
            raise RuntimeError("Label file is not ground-truth-ed")
        print("Fix silent")
        for s, e, label in lab:
            if self._is_silence(label):
                # Convert label times (units of 100 ns) to sample indices
                si = np.round(s * 1e-7 * sr).astype(np.uint32)
                ei = np.round(e * 1e-7 * sr).astype(np.uint32)
                ei = min(ei, wav.shape[0] - 1)
                wav[si:ei] = np.round(
                    wav[si:ei] * (1.0 - suppression_ratio)
                ).astype(np.int16)
        return wav

    def post_process(self, wav):
        """Band-pass filter the waveform and normalize it to 16-bit range."""
        wav = bandpass_filter(wav, self.config.sample_rate)
        return (wav / np.max(np.abs(wav)) * (2**15 - 1)).astype(np.int16)
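

# Minimal programmatic usage sketch (illustrative, not part of the original
# gist; the paths are placeholders for a packed SPSVS model and a label file):
#
#     spsvs = SPSVSMod("/path/to/packed_model", device="cpu")
#     labels = hts.load("/path/to/song.lab").round_()
#     mgc, lf0, vuv, bap = spsvs.get_static_features(labels)
#     wav, sr = spsvs.synthesize(mgc, lf0, vuv, bap, "world", 0.3)
#     wavfile.write("song.wav", rate=sr, data=wav)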


def get_parser():
    parser = argparse.ArgumentParser(description="Sample script for the SPSVS class")
    parser.add_argument("model_dir", metavar="MODEL_DIR", type=str,
                        help="packed model directory")
    parser.add_argument("in_dir", metavar="IN_DIR", type=str,
                        help="input directory")
    parser.add_argument("out_dir", metavar="OUT_DIR", type=str,
                        help="output directory")
    parser.add_argument("--utt_list", metavar="FILE", type=str, default="song_list.txt",
                        help="list of utterances (default: song_list.txt)")
    parser.add_argument("--vocoder_type", metavar="TYPE", type=str, default="world",
                        help="vocoder type, world or pwg (default: world)")
    parser.add_argument("--debug", action="store_true", help="debug mode")
    parser.add_argument("--device", metavar="DEVICE", type=str, default="cuda",
                        help="device (default: cuda)")
    parser.add_argument("--disable_post_filter", action="store_true",
                        help="disable post filter")
    parser.add_argument("--vuv_threshold", type=float, default=0.3,
                        help="V/UV threshold (default: 0.3)")
    parser.add_argument("--silence_threshold", type=float, default=1.0,
                        help="silence threshold for splitting (default: 1.0)")
    parser.add_argument("--min_duration", type=float, default=30.0,
                        help="minimum segment duration for splitting (default: 30.0)")
    parser.add_argument("--force_split_threshold", type=float, default=60.0,
                        help="force-split threshold (default: 60.0)")
    parser.add_argument("--vibrato_scale", type=float, default=1.0,
                        help="vibrato scale (default: 1.0)")
    parser.add_argument("--split", action="store_true", help="split synthesis")
    parser.add_argument("--return_states", action="store_true", help="return states")
    parser.add_argument("--fix_vuv", action="store_true",
                        help="force breath segments to be unvoiced")
    parser.add_argument("--force_whispering", action="store_true",
                        help="force whispering (all frames unvoiced)")
    parser.add_argument("--suppress_silent", action="store_true",
                        help="suppress silent regions")
    parser.add_argument("--suppression_ratio", type=float, default=0.8,
                        help="suppression ratio (default: 0.8)")
    return parser


if __name__ == "__main__":
    args = get_parser().parse_args(sys.argv[1:])
    post_filter = not args.disable_post_filter
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    logging.debug(args)

    spsvs = SPSVSMod(args.model_dir, args.device)
    logging.debug(spsvs)

    in_dir = expanduser(args.in_dir)
    out_dir = expanduser(args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    with open(join(in_dir, args.utt_list)) as f:
        lines = list(filter(lambda s: len(s.strip()) > 0, f.readlines()))

    print(f"Processing {len(lines)} utterances...")
    for utt_idx in tqdm(range(len(lines))):
        utt_id = lines[utt_idx].strip()
        label_path = join(in_dir, f"{utt_id}.lab")
        if not exists(label_path):
            raise RuntimeError(f"Label file does not exist: {label_path}")

        # Load labels and apply predicted time-lags and durations
        labels = hts.load(label_path).round_()
        modified_labels = spsvs.get_modified_labels(labels)

        if args.split:
            segments = spsvs.segment_labels(
                modified_labels,
                silence_threshold=args.silence_threshold,
                min_duration=args.min_duration,
                force_split_threshold=args.force_split_threshold,
            )
            print(f"Split into {len(segments)} segments...")
            for seg_idx, seg in tqdm(enumerate(segments)):
                with open(join(out_dir, f"{utt_id}_seg{seg_idx}.lab"), "w") as of:
                    of.write(str(seg))
                mgc, lf0, vuv, bap = spsvs.get_static_features(
                    seg,
                    post_filter=post_filter,
                    vuv_threshold=args.vuv_threshold,
                    vibrato_scale=args.vibrato_scale,
                    ground_truth=True,
                )
                if args.fix_vuv:
                    vuv = spsvs.fix_vuv(seg, vuv)
                if args.force_whispering:
                    # Make every frame unvoiced
                    vuv[:] = 0
                wav, sr = spsvs.synthesize(
                    mgc, lf0, vuv, bap, args.vocoder_type, args.vuv_threshold
                )
                if args.suppress_silent:
                    wav = spsvs.suppress_silent(seg, wav, sr, args.suppression_ratio)
                out_wav_path = join(out_dir, f"{utt_id}_seg{seg_idx}.wav")
                wavfile.write(out_wav_path, rate=sr, data=wav.astype(np.int16))
        else:
            mgc, lf0, vuv, bap = spsvs.get_static_features(
                modified_labels,
                post_filter=post_filter,
                vuv_threshold=args.vuv_threshold,
                vibrato_scale=args.vibrato_scale,
                ground_truth=True,
            )
            if args.fix_vuv:
                vuv = spsvs.fix_vuv(modified_labels, vuv)
            if args.force_whispering:
                # Make every frame unvoiced
                vuv[:] = 0
            wav, sr = spsvs.synthesize(
                mgc, lf0, vuv, bap, args.vocoder_type, args.vuv_threshold
            )
            if args.suppress_silent:
                wav = spsvs.suppress_silent(
                    modified_labels, wav, sr, args.suppression_ratio
                )
            out_wav_path = join(out_dir, f"{utt_id}.wav")
            wavfile.write(out_wav_path, rate=sr, data=wav.astype(np.int16))