Modified version of SPSVS, Statistical Parametric Singing Voice Synthesis
#! /usr/bin/python
import argparse
import logging
import os
import re
import statistics
import sys
from os.path import exists, expanduser, join

import numpy as np
import pysptk
import pyworld
import torch
import yaml
from nnmnkwii.frontend import NOTE_MAPPING
from nnmnkwii.frontend import merlin as fe
from nnmnkwii.io import hts
from nnmnkwii.postfilters import merlin_post_filter
from nnsvs.dsp import bandpass_filter
from nnsvs.gen import (
    gen_spsvs_static_features,
    gen_world_params,
    postprocess_duration,
    predict_acoustic,
    predict_duration,
    predict_timelag,
)
from nnsvs.multistream import get_static_stream_sizes, split_streams
from nnsvs.pitch import lowpass_filter
from nnsvs.postfilters import variance_scaling
from nnsvs.svs import SPSVS
from scipy.io import wavfile
from tqdm import tqdm


class SPSVSMod(SPSVS):
    """Modified version of SPSVS, Statistical Parametric Singing Voice Synthesis

    Args:
        model_dir (str): directory of the packed model
        device (str): "cpu" or "cuda"
    """

    def __init__(self, model_dir, device="cpu"):
        super().__init__(model_dir, device)

    @torch.no_grad()
    def get_modified_labels(self, labels):
        """Get timelag-and-duration modified labels

        Args:
            labels (nnmnkwii.io.hts.HTSLabelFile): HTS full-context labels

        Returns:
            nnmnkwii.io.hts.HTSLabelFile: HTS full-context labels with
            predicted time-lag and duration applied
        """
        hts_frame_shift = int(self.config.frame_period * 1e4)
        labels.frame_shift = hts_frame_shift

        # Time-lag
        lag = predict_timelag(
            self.device,
            labels,
            self.timelag_model,
            self.timelag_config,
            self.timelag_in_scaler,
            self.timelag_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.timelag.allowed_range,
            self.config.timelag.allowed_range_rest,
            self.config.timelag.force_clip_input_features,
        )
        # Duration predictions
        durations = predict_duration(
            self.device,
            labels,
            self.duration_model,
            self.duration_config,
            self.duration_in_scaler,
            self.duration_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.duration.force_clip_input_features,
        )
        # Normalize phoneme durations
        return postprocess_duration(labels, durations, lag)
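
    # A minimal usage sketch of the label modification stage (file names are
    # hypothetical; the model directory must be a packed nnsvs model):
    #
    #   spsvs = SPSVSMod("packed_model_dir", device="cpu")
    #   labels = hts.load("song.lab").round_()
    #   modified_labels = spsvs.get_modified_labels(labels)
    #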

    @torch.no_grad()
    def get_static_features(
        self,
        labels,
        post_filter_type="merlin",
        gv_postfilter=True,
        trajectory_smoothing=True,
        trajectory_smoothing_cutoff=50,
        trajectory_smoothing_cutoff_f0=20,
        vuv_threshold=0.1,
        vibrato_scale=1.0,
        force_fix_vuv=False,
        ground_truth=False,
    ):
        """Get static features from given HTS full-context labels

        Args:
            labels (nnmnkwii.io.hts.HTSLabelFile): HTS full-context labels

        Returns:
            tuple: (mgc, lf0, vuv, bap) as numpy.ndarray
        """
        hts_frame_shift = int(self.config.frame_period * 1e4)

        if post_filter_type not in ["merlin", "nnsvs", "gv", "none"]:
            raise ValueError(f"Unknown post-filter type: {post_filter_type}")

        if ground_truth:
            modified_labels = labels
        else:
            modified_labels = self.get_modified_labels(labels)

        acoustic_features = predict_acoustic(
            self.device,
            modified_labels,
            self.acoustic_model,
            self.acoustic_config,
            self.acoustic_in_scaler,
            self.acoustic_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.config.acoustic.subphone_features,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.acoustic.force_clip_input_features,
        )

        static_stream_sizes = get_static_stream_sizes(
            self.acoustic_config.stream_sizes,
            self.acoustic_config.has_dynamic_features,
            self.acoustic_config.num_windows,
        )

        # Apply GV post-filtering
        if post_filter_type in ["nnsvs", "gv"] and gv_postfilter:
            logging.debug("Apply GV post-filtering")
            linguistic_features = fe.linguistic_features(
                modified_labels,
                self.binary_dict,
                self.numeric_dict,
                add_frame_features=True,
                subphone_features=self.config.acoustic.subphone_features,
                frame_shift=hts_frame_shift,
            )
            # TODO: remove hardcode
            in_rest_idx = 0
            note_frame_indices = linguistic_features[:, in_rest_idx] <= 0
            mgc_end_dim = static_stream_sizes[0]
            acoustic_features[:, :mgc_end_dim] = variance_scaling(
                self.acoustic_out_scaler.var_.reshape(-1)[:mgc_end_dim],
                acoustic_features[:, :mgc_end_dim],
                offset=2,
                note_frame_indices=note_frame_indices,
            )

        # Learned post-filter using nnsvs
        if post_filter_type == "nnsvs" and self.postfilter_model is not None:
            logging.debug("Apply mgc_postfilter")
            # (1) Raw spectrogram or (2) mgc
            rawsp_output = self.postfilter_config.stream_sizes[0] >= 128
            # If the post-filter output is a raw spectrogram, convert mgc to log spectrogram
            if rawsp_output:
                outs = split_streams(acoustic_features, static_stream_sizes)
                assert len(outs) == 4
                mgc, lf0, vuv, bap = outs
                fft_size = pyworld.get_cheaptrick_fft_size(self.config.sample_rate)
                sp = pyworld.decode_spectral_envelope(
                    mgc.astype(np.float64), self.config.sample_rate, fft_size
                ).astype(np.float32)
                sp = np.log(sp)
                acoustic_features = np.concatenate([sp, lf0, vuv, bap], axis=-1)

            in_feats = torch.from_numpy(acoustic_features).float().unsqueeze(0)
            in_feats = self.postfilter_out_scaler.transform(in_feats).float().to(self.device)
            # Run inference
            out_feats = self.postfilter_model.inference(in_feats, [in_feats.shape[1]])
            acoustic_features = (
                self.postfilter_out_scaler.inverse_transform(out_feats.cpu())
                .squeeze(0)
                .numpy()
            )

            # Convert log spectrogram to mgc
            # NOTE: mgc is used to reduce possible artifacts
            # Ref: https://bit.ly/3AHjstU
            if rawsp_output:
                sp, lf0, vuv, bap = split_streams(
                    acoustic_features, self.postfilter_config.stream_sizes
                )
                sp = np.exp(sp)
                mgc = pyworld.code_spectral_envelope(
                    sp.astype(np.float64), self.config.sample_rate, 60
                ).astype(np.float32)
                acoustic_features = np.concatenate([mgc, lf0, vuv, bap], axis=-1)

        # Generate WORLD parameters
        mgc, lf0, vuv, bap = gen_spsvs_static_features(
            modified_labels,
            acoustic_features,
            self.binary_dict,
            self.numeric_dict,
            self.acoustic_config.stream_sizes,
            self.acoustic_config.has_dynamic_features,
            self.config.acoustic.subphone_features,
            self.pitch_idx,
            self.acoustic_config.num_windows,
            self.config.frame_period,
            self.config.acoustic.relative_f0,
            vibrato_scale=vibrato_scale,
            vuv_threshold=vuv_threshold,
            force_fix_vuv=force_fix_vuv,
        )

        # NOTE: spectral enhancement based on Merlin's post-filter implementation
        if post_filter_type == "merlin":
            logging.debug("Apply merlin postfilter")
            alpha = pysptk.util.mcepalpha(self.config.sample_rate)
            mgc = merlin_post_filter(mgc, alpha)

        # Remove high-frequency components of mgc/bap
        # NOTE: It seems to be effective to suppress artifacts of GAN-based post-filtering
        if trajectory_smoothing:
            logging.debug("Apply trajectory smoothing")
            modfs = int(1 / (self.config.frame_period * 0.001))
            lf0[:, 0] = lowpass_filter(
                lf0[:, 0], modfs, cutoff=trajectory_smoothing_cutoff_f0
            )
            for d in range(mgc.shape[1]):
                mgc[:, d] = lowpass_filter(
                    mgc[:, d], modfs, cutoff=trajectory_smoothing_cutoff
                )
            for d in range(bap.shape[1]):
                bap[:, d] = lowpass_filter(
                    bap[:, d], modfs, cutoff=trajectory_smoothing_cutoff
                )

        return mgc, lf0, vuv, bap

    @torch.no_grad()
    def synthesize(
        self,
        mgc,
        lf0,
        vuv,
        bap,
        vocoder_type="world",
        vuv_threshold=0.1,
        apply_post_process=True,
    ):
        """Get waveform from given acoustic features.

        Args:
            mgc (numpy.ndarray): mel-generalized cepstrum
            lf0 (numpy.ndarray): log F0
            vuv (numpy.ndarray): V/UV
            bap (numpy.ndarray): band aperiodicity
            vocoder_type (str): world, pwg or usfgan
            vuv_threshold (float): V/UV threshold

        Returns:
            tuple: (waveform as numpy.ndarray, sample rate)
        """
        use_mcep_aperiodicity = bap.shape[-1] > 5
        logging.debug(f"use_mcep_aperiodicity: {use_mcep_aperiodicity}")
        if not use_mcep_aperiodicity:
            bap = np.clip(bap, a_min=-60, a_max=0)

        use_world_codec = self.config.get("use_world_codec", False)
        logging.debug(f"use_world_codec: {use_world_codec}")

        vocoder_type = vocoder_type.lower()
        if vocoder_type not in ["world", "pwg", "usfgan"]:
            raise ValueError(f"Unknown vocoder type: {vocoder_type}")

        # Waveform generation by (1) WORLD or (2) neural vocoder
        if vocoder_type == "world":
            f0, spectrogram, aperiodicity = gen_world_params(
                mgc,
                lf0,
                vuv,
                bap,
                self.config.sample_rate,
                vuv_threshold=vuv_threshold,
                use_world_codec=use_world_codec,
            )
            wav = pyworld.synthesize(
                f0,
                spectrogram,
                aperiodicity,
                self.config.sample_rate,
                self.config.frame_period,
            )
        elif vocoder_type == "pwg":
            # NOTE: So far vocoder models are trained on binary V/UV features
            vuv = (vuv > vuv_threshold).astype(np.float32)
            voc_inp = (
                torch.from_numpy(
                    self.vocoder_in_scaler.transform(
                        np.concatenate([mgc, lf0, vuv, bap], axis=-1)
                    )
                )
                .float()
                .to(self.device)
            )
            wav = self.vocoder.inference(voc_inp).view(-1).to("cpu").numpy()
        elif vocoder_type == "usfgan":
            fftlen = pyworld.get_cheaptrick_fft_size(self.config.sample_rate)
            if use_mcep_aperiodicity:
                aperiodicity_order = bap.shape[-1] - 1
                alpha = pysptk.util.mcepalpha(self.config.sample_rate)
                aperiodicity = pysptk.mc2sp(
                    np.ascontiguousarray(bap).astype(np.float64),
                    fftlen=fftlen,
                    alpha=alpha,
                )
            else:
                aperiodicity = pyworld.decode_aperiodicity(
                    np.ascontiguousarray(bap).astype(np.float64),
                    self.config.sample_rate,
                    fftlen,
                )
            # Fill aperiodicity with ones for unvoiced regions
            aperiodicity[vuv.reshape(-1) < vuv_threshold, 0] = 1.0
            # WORLD fails catastrophically for out-of-range aperiodicity
            aperiodicity = np.clip(aperiodicity, 0.0, 1.0)
            if use_mcep_aperiodicity:
                bap = pysptk.sp2mc(
                    aperiodicity,
                    order=aperiodicity_order,
                    alpha=alpha,
                )
            else:
                bap = pyworld.code_aperiodicity(
                    aperiodicity, self.config.sample_rate
                ).astype(np.float32)

            aux_feats = [mgc, bap]
            aux_feats = (
                torch.from_numpy(
                    self.vocoder_in_scaler.transform(np.concatenate(aux_feats, axis=-1))
                )
                .float()
                .to(self.device)
            )
            contf0 = np.exp(lf0)
            if self.vocoder_config.data.sine_f0_type in ["contf0", "cf0"]:
                f0_inp = contf0
            elif self.vocoder_config.data.sine_f0_type == "f0":
                f0_inp = contf0
                f0_inp[vuv < vuv_threshold] = 0
            wav = self.vocoder.inference(f0_inp, aux_feats).view(-1).to("cpu").numpy()

        if apply_post_process:
            wav = self.post_process(wav)

        return wav, self.config.sample_rate
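
    # A minimal end-to-end sketch, mirroring how the main block below drives the
    # class (paths are hypothetical; keyword values are the method defaults):
    #
    #   spsvs = SPSVSMod("packed_model_dir", device="cpu")
    #   labels = hts.load("song.lab").round_()
    #   mgc, lf0, vuv, bap = spsvs.get_static_features(labels, post_filter_type="merlin")
    #   wav, sr = spsvs.synthesize(mgc, lf0, vuv, bap, vocoder_type="world")
    #   wavfile.write("song.wav", rate=sr, data=wav.astype(np.int16))
    #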

    def _compute_nosil_duration(self, lab, threshold=5.0):
        # Total duration in seconds, excluding silences longer than `threshold` seconds
        is_full_context = "@" in lab[0][-1]
        sum_d = 0
        for s, e, label in lab:
            d = (e - s) * 1e-7
            if is_full_context:
                is_silence = "-sil" in label or "-pau" in label
            else:
                is_silence = label == "sil" or label == "pau"
            if not (is_silence and d > threshold):
                sum_d += d
        return sum_d

    def _is_silence(self, label):
        is_full_context = "@" in label
        if is_full_context:
            is_silence = "-sil" in label or "-pau" in label
        else:
            is_silence = label == "sil" or label == "pau"
        return is_silence

    def segment_labels(
        self,
        lab,
        strict=True,
        silence_threshold=1.0,
        min_duration=10.0,
        force_split_threshold=30.0,
    ):
        """Segment labels based on sil/pau

        Example:
            [a b c sil d e f pau g h i sil j k l]
            -> [a b c] [d e f] [g h i] [j k l]
        """
        segments = []
        seg = hts.HTSLabelFile()
        start_indices = []
        end_indices = []
        si = 0
        large_silence_detected = False

        for idx, (s, e, label) in enumerate(lab):
            d = (e - s) * 1e-7
            is_silence = self._is_silence(label)
            if len(seg) > 0:
                # Compute duration except for long silences
                seg_d = self._compute_nosil_duration(seg)
            else:
                seg_d = 0
            # Let's try to split. If we find a large silence, force a split
            # regardless of min_duration.
            if (d > force_split_threshold) or (
                is_silence and d > silence_threshold and seg_d > min_duration
            ):
                if idx == len(lab) - 1:
                    continue
                elif len(seg) > 0:
                    large_silence_detected = d > force_split_threshold
                    seg.append((s, e, label), strict)
                    start_indices.append(si)
                    end_indices.append(idx)
                    segments.append(seg)
                    seg = hts.HTSLabelFile()
                    si = idx
                seg.append((s, e, label), strict)
                continue
            else:
                if len(seg) == 0:
                    si = idx
                seg.append((s, e, label), strict)

        if len(seg) > 0:
            seg_d = self._compute_nosil_duration(seg)
            # If the last segment is short, combine it with the previous segment.
            if seg_d < min_duration and not large_silence_detected and end_indices:
                end_indices[-1] = si + len(seg) - 1
            else:
                start_indices.append(si)
                end_indices.append(si + len(seg) - 1)

        segments2 = []
        for s, e in zip(start_indices, end_indices):
            seg = lab[s : e + 1]
            offset = seg.start_times[0]
            seg.start_times = np.asarray(seg.start_times) - offset
            seg.end_times = np.asarray(seg.end_times) - offset
            segments2.append(seg)
        return segments2
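
    # Sketch of typical segmentation use (thresholds are the method defaults;
    # each returned segment is an offset-normalized HTSLabelFile):
    #
    #   segments = spsvs.segment_labels(modified_labels)
    #   for i, seg in enumerate(segments):
    #       print(i, len(seg))
    #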

    def fix_br_vuv(self, lab, vuv):
        if len(lab.start_times) != len(set(lab.start_times)):
            raise RuntimeError("Label file is not ground-truthed (duplicated start times)")
        print("Fix V/UV")
        for idx, (s, e, label) in tqdm(enumerate(lab)):
            if "-br" in label:
                si = np.round(s / (self.config.frame_period * 1e4)).astype(np.uint32)
                ei = np.round(e / (self.config.frame_period * 1e4)).astype(np.uint32)
                if ei > (vuv.shape[0] - 1):
                    ei = vuv.shape[0] - 1
                vuv[si:ei] = 0
        return vuv

    def suppress_silent(self, lab, wav, sr, suppression_ratio):
        if len(lab.start_times) != len(set(lab.start_times)):
            raise RuntimeError("Label file is not ground-truthed (duplicated start times)")
        print("Fix silent")
        for idx, (s, e, label) in enumerate(lab):
            if self._is_silence(label):
                si = np.round(s * 1e-7 * sr).astype(np.uint32)
                ei = np.round(e * 1e-7 * sr).astype(np.uint32)
                if ei > (wav.shape[0] - 1):
                    ei = wav.shape[0] - 1
                wav[si:ei] = np.round(wav[si:ei] * (1.0 - suppression_ratio)).astype(np.int16)
        return wav

    def post_process(self, wav):
        logging.debug("Apply Post Process")
        wav = bandpass_filter(wav, self.config.sample_rate)
        if np.max(wav) > 10:
            if np.abs(wav).max() > 32767:
                wav = wav / np.abs(wav).max()
            # data is likely already in [-32768, 32767]
            wav = wav.astype(np.int16)
        else:
            if np.abs(wav).max() > 1.0:
                wav = wav / np.abs(wav).max()
            wav = (wav * 32767.0).astype(np.int16)
        return wav
        # return (wav / np.max(np.abs(wav)) * (2 ** 15 - 1)).astype(np.int16)

    def embed_spk(self, full_labels, spk):
        new_contexts = []
        for _, _, context in full_labels:
            for id, pre, post in [("p16", "\\]", "\\/")]:
                match = re.search(f"{pre}([0-9x]+){post}", context)
                if match is not None:
                    assert len(match.groups()) == 1
                    num = match.group(0)[1:-1]
                    if len(num) > 0:
                        pre = pre.replace("\\", "")
                        post = post.replace("\\", "")
                        context = context.replace(
                            match.group(0), f"{pre}{spk}{post}", 1
                        )
                else:
                    raise RuntimeError(f"Cannot find {id} label")
            new_contexts.append(context)

        new_full_labels = hts.HTSLabelFile()
        new_full_labels.start_times = full_labels.start_times
        new_full_labels.end_times = full_labels.end_times
        new_full_labels.contexts = new_contexts
        return new_full_labels
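
# A small sketch of what embed_spk rewrites, assuming the p16 context field holds
# the speaker id between "]" and "/" (the fragment below is shortened and
# hypothetical, not a full HTS context):
#
#   "...]0/..."  --embed_spk(labels, 2)-->  "...]2/..."
#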

def convert_bpm(labels, source_bpm, target_bpm, autodetect_source_bpm=True):
    if autodetect_source_bpm:
        bpm = []
        e5re1 = re.compile(r'^.+~(\d+)!.+$')
        for i in range(len(labels)):
            matched = re.match(e5re1, labels.contexts[i])
            bpm.append(int(matched[1]))
        _source_bpm = statistics.mode(bpm)
        logging.debug(f"auto-detected source bpm: {_source_bpm}")
    elif source_bpm != 0:
        _source_bpm = source_bpm
    else:
        raise RuntimeError("Invalid source BPM and autodetect_source_bpm is disabled")
    if target_bpm != 0:
        _target_bpm = int(target_bpm)
    else:
        raise RuntimeError("Target BPM is not specified")
    convert_ratio = float(_target_bpm / _source_bpm)
    logging.debug(f"convert_ratio: {convert_ratio}")
    regexp = re.compile(
        r'(?P<p1_d4>^.+%)(?P<d5>\w+)(?P<d6_e4>\|.+~)(?P<e5>\w+)(?P<e6>!.+@)(?P<e7>\w+)'
        r'(?P<e8_e11>#.*\|)(?P<e12>\w+)(?P<e12_e13>\[)(?P<e13>\w+)(?P<e14_e19>&.*_)'
        r'(?P<e20>\w+)(?P<e20_e21>;)(?P<e21>\w+)(?P<e22_f4>\$.*\$)(?P<f5>\w+)(?P<f6_j3>\$.*$)'
    )
    for i in range(len(labels)):
        labels.start_times[i] = round(labels.start_times[i] / convert_ratio / 50000) * 50000
        labels.end_times[i] = round(labels.end_times[i] / convert_ratio / 50000) * 50000
        m = re.match(regexp, labels.contexts[i])
        # logging.debug(f"p1_d4: {m.group('p1_d4')}")
        new_d5 = str(round(int(m.group('d5')) * convert_ratio)) if m.group('d5') != 'xx' else m.group('d5')
        new_e5 = str(round(int(m.group('e5')) * convert_ratio)) if m.group('e5') != 'xx' else m.group('e5')
        new_e7 = str(round(int(m.group('e7')) / convert_ratio)) if m.group('e7') != 'xx' else m.group('e7')
        new_e12 = str(round(int(m.group('e12')) / convert_ratio)) if m.group('e12') != 'xx' else m.group('e12')
        new_e13 = str(round(int(m.group('e13')) / convert_ratio)) if m.group('e13') != 'xx' else m.group('e13')
        new_e20 = str(round(int(m.group('e20')) / convert_ratio)) if m.group('e20') != 'xx' else m.group('e20')
        new_e21 = str(round(int(m.group('e21')) / convert_ratio)) if m.group('e21') != 'xx' else m.group('e21')
        new_f5 = str(round(int(m.group('f5')) * convert_ratio)) if m.group('f5') != 'xx' else m.group('f5')
        new_context = m.group('p1_d4') + new_d5 + m.group('d6_e4') + new_e5 + m.group('e6') + \
            new_e7 + m.group('e8_e11') + new_e12 + m.group('e12_e13') + new_e13 + m.group('e14_e19') + \
            new_e20 + m.group('e20_e21') + new_e21 + m.group('e22_f4') + new_f5 + m.group('f6_j3')
        assert abs(len(labels.contexts[i]) - len(new_context)) < 10
        # logging.debug(f"new_context: {new_context}")
        labels.contexts[i] = new_context
    return labels, _source_bpm, _target_bpm
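
# Hedged usage sketch: retime a score from an auto-detected tempo to 90 BPM
# (labels is assumed to be an HTSLabelFile with Sinsy-style full contexts):
#
#   labels, src_bpm, tgt_bpm = convert_bpm(labels, source_bpm=0, target_bpm=90)
#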

def get_shifted_midi_name(midi_name, pitch_shift):
    if pitch_shift == 0:
        return midi_name
    if midi_name == 'xx':
        return midi_name
    assert midi_name in NOTE_MAPPING, f"Failed to lookup {midi_name} from NOTE_MAPPING"
    new_midi_name = [k for k, v in NOTE_MAPPING.items() if v == NOTE_MAPPING[midi_name] + pitch_shift]
    assert len(new_midi_name) == 1, f"Failed to shift {midi_name} by {pitch_shift}"
    return new_midi_name[0]
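
# For example, a shift of +2 looks up the note name two semitones above in
# nnmnkwii's NOTE_MAPPING; "xx" (no note) and a shift of 0 pass through:
#
#   get_shifted_midi_name("xx", 2)  # -> "xx"
#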

def pitch_shift(labels, pitch_shift):
    regexp = re.compile(r'(?P<p1_d1>^.+/D:)(?P<d1>\w+)(?P<d1_e1>!.+/E:)(?P<e1>\w+)(?P<e1_f1>\].+/F:)(?P<f1>\w+)(?P<f1_j3>#.+$)')
    for i in range(len(labels)):
        m = re.match(regexp, labels.contexts[i])
        new_d1 = get_shifted_midi_name(m.group('d1'), pitch_shift)
        new_e1 = get_shifted_midi_name(m.group('e1'), pitch_shift)
        new_f1 = get_shifted_midi_name(m.group('f1'), pitch_shift)
        new_context = m.group('p1_d1') + new_d1 + m.group('d1_e1') + new_e1 + m.group('e1_f1') + \
            new_f1 + m.group('f1_j3')
        # logging.debug(f"new_context: {new_context}")
        labels.contexts[i] = new_context
    return labels

def force_restricted_ph_duration(labels, ph_stats, restricted_ph_level):
    s = labels.frame_shift
    bad_phs = ["n", "m", "y", "s", "sh"]
    logging.debug(f"restricted_ph_level: {restricted_ph_level}")
    regexp = re.compile(r'(?P<p1_p3>^.+-)(?P<p4>\w+)(?P<p5_j3>\+.*$)')
    for i in range(len(labels) - 1):
        # logging.debug(f"contexts: {labels.contexts[i]}")
        m = re.match(regexp, labels.contexts[i])
        ph = m.group('p4')
        logging.debug(f"p4: {ph}")
        if ph == "sil":
            logging.debug("skipping")
            continue
        if restricted_ph_level == "1sd":
            dur_ul = np.int64((float(ph_stats[ph]["mean"]) + float(ph_stats[ph]["std"])) * 1e7)
        elif restricted_ph_level == "2sd":
            dur_ul = np.int64((float(ph_stats[ph]["mean"]) + float(ph_stats[ph]["std"]) * 2) * 1e7)
        elif restricted_ph_level == "max":
            dur_ul = np.int64(float(ph_stats[ph]["max"]) * 1e7)
        else:
            raise RuntimeError(f"Wrong restricted_ph_level: {restricted_ph_level}")
        logging.debug(f"dur_ul: {dur_ul}")
        if labels.end_times[i] - labels.start_times[i] > dur_ul:
            logging.debug(f"The duration of phoneme {ph} at {labels.start_times[i]} exceeds dur_ul: {dur_ul}.")
            if ph not in bad_phs:
                logging.debug("skipping")
                continue
            logging.debug(f"Original start and end times: {labels.start_times[i]} {labels.end_times[i]}")
            labels.end_times[i] = labels.start_times[i] + np.round(dur_ul / s).astype(np.int64) * s
            labels.start_times[i+1] = labels.end_times[i]
            logging.debug(f"Modified start and end times: {labels.start_times[i]} {labels.end_times[i]}")
            logging.debug(f"Next phoneme starts at {labels.start_times[i+1]}")
    return labels
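
# The ph_stats mapping is loaded from YAML in the main block below; the lookups
# above assume per-phoneme duration statistics in seconds, e.g. (values are
# illustrative only):
#
#   a:
#     mean: 0.12
#     std: 0.05
#     max: 0.43
#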

def get_parser():
    parser = argparse.ArgumentParser(
        description='Sample script for SPSVS class'
    )
    parser.add_argument('model_dir', metavar='MODEL_DIR', type=str, help='packed model dir')
    parser.add_argument('in_dir', metavar='IN_DIR', type=str, help='input directory')
    parser.add_argument('out_dir', metavar='OUT_DIR', type=str, help='output directory')
    parser.add_argument('--utt_list', metavar='FILE', type=str, default='song_list.txt', help='list of utterances (default: song_list.txt)')
    parser.add_argument('--vocoder_type', metavar='TYPE', type=str, default='world', help='vocoder type (default: world)')
    parser.add_argument('--post_filter_type', metavar='TYPE', type=str, default='merlin', help='post filter type (default: merlin)')
    parser.add_argument('--disable_gv_postfilter', action='store_true', help='disable GV postfilter')
    parser.add_argument('--disable_trajectory_smoothing', action='store_true', help='disable trajectory smoothing')
    parser.add_argument('--trajectory_smoothing_cutoff', type=int, default=50, help='trajectory smoothing cutoff')
    parser.add_argument('--trajectory_smoothing_cutoff_f0', type=int, default=20, help='trajectory smoothing cutoff for F0')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--device', metavar='DEVICE', type=str, default='cuda', help='device (default: cuda)')
    parser.add_argument('--vuv_threshold', type=float, default=0.3, help='V/UV threshold (default: 0.3)')
    parser.add_argument('--silence_threshold', type=float, default=1.0, help='silence threshold for split (default: 1.0)')
    parser.add_argument('--min_duration', type=float, default=30.0, help='minimum duration for split (default: 30.0)')
    parser.add_argument('--force_split_threshold', type=float, default=60.0, help='force split threshold (default: 60.0)')
    parser.add_argument('--vibrato_scale', type=float, default=1.0, help='vibrato scale (default: 1.0)')
    parser.add_argument('--split', action='store_true', help='split synthesis')
    parser.add_argument('--spk', type=int, default=0, help='SPK identifier (default: 0)')
    parser.add_argument('--return_states', action='store_true', help='return states')
    parser.add_argument('--force_fix_vuv', action='store_true', help='enable force-fix V/UV feature')
    parser.add_argument('--fix_br_vuv', action='store_true', help='fix breath (br) V/UV')
    parser.add_argument('--force_whispering', action='store_true', help='force whispering')
    parser.add_argument('--suppress_silent', action='store_true', help='suppress silent regions')
    parser.add_argument('--suppression_ratio', type=float, default=0.8, help='suppression ratio (default: 0.8)')
    parser.add_argument('--disable_post_process', action='store_true', help='disable post process')
    parser.add_argument('--convert_bpm', action='store_true', help='convert BPM')
    parser.add_argument('--autodetect_source_bpm', action='store_true', help='detect source BPM automatically')
    parser.add_argument('--source_bpm', type=int, default=0, help='source BPM')
    parser.add_argument('--target_bpm', type=int, default=0, help='target BPM')
    parser.add_argument('--pitch_shift', type=int, default=0, help='pitch shift in semitones')
    parser.add_argument('--ph_stats_yaml', type=str, default='', help='phoneme statistics YAML file')
    parser.add_argument('--restricted_ph_level', type=str, default='none', help='restriction level of phoneme duration (none, 1sd, 2sd, max)')
    return parser
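
# Example invocation (the script name and paths are hypothetical; song_list.txt
# holds one utterance id per line, with a matching <utt_id>.lab in IN_DIR):
#
#   python spsvs_mod.py packed_model_dir in_dir out_dir \
#       --vocoder_type world --post_filter_type merlin --split --debug
#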

if __name__ == "__main__":
    args = get_parser().parse_args(sys.argv[1:])
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    logging.debug(args)

    trajectory_smoothing = not args.disable_trajectory_smoothing
    gv_postfilter = not args.disable_gv_postfilter
    apply_post_process = not args.disable_post_process

    spsvs = SPSVSMod(args.model_dir, args.device)
    logging.debug(spsvs)

    in_dir = expanduser(args.in_dir)
    out_dir = expanduser(args.out_dir)
    os.makedirs(out_dir, exist_ok=True)
    postfix = ""

    with open(expanduser(join(in_dir, args.utt_list))) as f:
        lines = list(filter(lambda s: len(s.strip()) > 0, f.readlines()))
        print(f"Processing {len(lines)} utterances...")
        for idx in tqdm(range(len(lines))):
            utt_id = lines[idx].strip()
            label_path = join(in_dir, f"{utt_id}.lab")
            if not exists(label_path):
                raise RuntimeError(f"Label file does not exist: {label_path}")

            # Load labels and question
            labels = hts.load(label_path).round_()
            modified_labels = spsvs.get_modified_labels(labels)

            # Embed spk
            modified_labels = spsvs.embed_spk(modified_labels, args.spk)

            if args.convert_bpm:
                modified_labels, source_bpm, target_bpm = convert_bpm(modified_labels, args.source_bpm, args.target_bpm, args.autodetect_source_bpm)
                postfix += f"_convert_from_{source_bpm}_to_{target_bpm}"

            if args.pitch_shift != 0:
                modified_labels = pitch_shift(modified_labels, args.pitch_shift)
                sign = "p" if args.pitch_shift > 0 else "m"
                postfix += f"_pitch_shift_{sign}{abs(args.pitch_shift)}"

            if args.restricted_ph_level != "none":
                ph_stats = {}
                with open(args.ph_stats_yaml, 'r', encoding='utf-8') as yml:
                    ph_stats = yaml.load(yml, Loader=yaml.BaseLoader)
                if len(ph_stats) == 0:
                    raise RuntimeError(f"Wrong phoneme statistics yaml file {args.ph_stats_yaml}.")
                modified_labels = force_restricted_ph_duration(modified_labels, ph_stats, args.restricted_ph_level)
                postfix += f"_force_restricted_ph_{args.restricted_ph_level}"

            if args.split:
                segments = spsvs.segment_labels(
                    modified_labels,
                    silence_threshold=args.silence_threshold,
                    min_duration=args.min_duration,
                    force_split_threshold=args.force_split_threshold,
                )
                print(f"Split into {len(segments)} segments...")
                for idx, seg in tqdm(enumerate(segments)):
                    with open(join(out_dir, f"{utt_id}{postfix}_spk{args.spk}_seg{idx}.lab"), "w") as of:
                        of.write(str(seg))
                    mgc, lf0, vuv, bap = spsvs.get_static_features(
                        seg,
                        args.post_filter_type,
                        gv_postfilter,
                        trajectory_smoothing,
                        args.trajectory_smoothing_cutoff,
                        vuv_threshold=args.vuv_threshold,
                        vibrato_scale=args.vibrato_scale,
                        force_fix_vuv=args.force_fix_vuv,
                        ground_truth=True,
                    )
                    if args.fix_br_vuv:
                        vuv = spsvs.fix_br_vuv(seg, vuv)
                    if args.force_whispering:
                        vuv[:] = 0
                    wav, sr = spsvs.synthesize(
                        mgc,
                        lf0,
                        vuv,
                        bap,
                        args.vocoder_type,
                        args.vuv_threshold,
                        apply_post_process,
                    )
                    if args.suppress_silent:
                        wav = spsvs.suppress_silent(seg, wav, sr, args.suppression_ratio)
                    out_wav_path = join(out_dir, f"{utt_id}{postfix}_spk{args.spk}_seg{idx}.wav")
                    wavfile.write(out_wav_path, rate=sr, data=wav.astype(np.int16))
            else:
                with open(join(out_dir, f"{utt_id}{postfix}_spk{args.spk}.lab"), "w") as of:
                    of.write(str(modified_labels))
                mgc, lf0, vuv, bap = spsvs.get_static_features(
                    modified_labels,
                    args.post_filter_type,
                    gv_postfilter,
                    trajectory_smoothing,
                    args.trajectory_smoothing_cutoff,
                    vuv_threshold=args.vuv_threshold,
                    vibrato_scale=args.vibrato_scale,
                    ground_truth=True,
                )
                if args.fix_br_vuv:
                    vuv = spsvs.fix_br_vuv(modified_labels, vuv)
                if args.force_whispering:
                    vuv[:] = 0
                wav, sr = spsvs.synthesize(
                    mgc,
                    lf0,
                    vuv,
                    bap,
                    args.vocoder_type,
                    args.vuv_threshold,
                    apply_post_process,
                )
                if args.suppress_silent:
                    wav = spsvs.suppress_silent(modified_labels, wav, sr, args.suppression_ratio)
                out_wav_path = join(out_dir, f"{utt_id}{postfix}_spk{args.spk}.wav")
                wavfile.write(out_wav_path, rate=sr, data=wav.astype(np.int16))