@taroushirani
Created November 10, 2024 15:15
Modified version of SPSVS, Statistical Parametric Singing Voice Synthesis
#! /usr/bin/python
import argparse
import logging
from nnmnkwii.frontend import merlin as fe
from nnmnkwii.frontend import NOTE_MAPPING
from nnmnkwii.io import hts
from nnmnkwii.postfilters import merlin_post_filter
from nnsvs.dsp import bandpass_filter
from nnsvs.gen import (
    gen_spsvs_static_features,
    gen_world_params,
    postprocess_duration,
    predict_acoustic,
    predict_duration,
    predict_timelag,
)
from nnsvs.multistream import (
    get_static_stream_sizes,
    split_streams,
)
from nnsvs.pitch import lowpass_filter
from nnsvs.postfilters import variance_scaling
from nnsvs.svs import SPSVS
import numpy as np
import os
from os.path import expanduser, exists, join
import pysptk
import pyworld
import re
import sys
import statistics
from scipy.io import wavfile
import torch
from tqdm import tqdm
import yaml


class SPSVSMod(SPSVS):
    """Modified version of SPSVS, Statistical Parametric Singing Voice Synthesis

    Args:
        model_dir (str): directory of the model
        device (str): cpu or cuda
    """

    def __init__(self, model_dir, device="cpu"):
        super(SPSVSMod, self).__init__(model_dir, device)

    @torch.no_grad()
    def get_modified_labels(self, labels):
        """Get timelag-and-duration modified labels

        Args:
            labels (nnmnkwii.io.HTSLabelFile): HTS full-context labels

        Returns:
            labels (nnmnkwii.io.HTSLabelFile): HTS full-context labels with the
            predicted time-lags and phoneme durations applied
        """
        hts_frame_shift = int(self.config.frame_period * 1e4)
        labels.frame_shift = hts_frame_shift

        # Time-lag
        lag = predict_timelag(
            self.device,
            labels,
            self.timelag_model,
            self.timelag_config,
            self.timelag_in_scaler,
            self.timelag_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.timelag.allowed_range,
            self.config.timelag.allowed_range_rest,
            self.config.timelag.force_clip_input_features,
        )
        # Duration predictions
        durations = predict_duration(
            self.device,
            labels,
            self.duration_model,
            self.duration_config,
            self.duration_in_scaler,
            self.duration_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.duration.force_clip_input_features,
        )
        # Normalize phoneme durations
        return postprocess_duration(labels, durations, lag)
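
    # A minimal sketch of inspecting the modified timing, given an SPSVSMod
    # instance ``engine`` (instance name and file paths are hypothetical):
    #
    #   labels = hts.load("song.lab").round_()
    #   modified = engine.get_modified_labels(labels)
    #   with open("song_mod.lab", "w") as f:
    #       f.write(str(modified))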

    @torch.no_grad()
    def get_static_features(
        self,
        labels,
        post_filter_type="merlin",
        gv_postfilter=True,
        trajectory_smoothing=True,
        trajectory_smoothing_cutoff=50,
        trajectory_smoothing_cutoff_f0=20,
        vuv_threshold=0.1,
        vibrato_scale=1.0,
        force_fix_vuv=False,
        ground_truth=False,
    ):
        """Get static features from given HTS full-context labels

        Args:
            labels (nnmnkwii.io.HTSLabelFile): HTS full-context labels

        If ``ground_truth`` is True, the labels are used as-is, without
        timelag/duration prediction.

        Returns:
            tuple of numpy.ndarray: (mgc, lf0, vuv, bap)
        """
        hts_frame_shift = int(self.config.frame_period * 1e4)
        if post_filter_type not in ["merlin", "nnsvs", "gv", "none"]:
            raise ValueError(f"Unknown post-filter type: {post_filter_type}")

        if ground_truth:
            modified_labels = labels
        else:
            modified_labels = self.get_modified_labels(labels)

        acoustic_features = predict_acoustic(
            self.device,
            modified_labels,
            self.acoustic_model,
            self.acoustic_config,
            self.acoustic_in_scaler,
            self.acoustic_out_scaler,
            self.binary_dict,
            self.numeric_dict,
            self.config.acoustic.subphone_features,
            self.pitch_indices,
            self.config.log_f0_conditioning,
            self.config.acoustic.force_clip_input_features,
        )
        static_stream_sizes = get_static_stream_sizes(
            self.acoustic_config.stream_sizes,
            self.acoustic_config.has_dynamic_features,
            self.acoustic_config.num_windows,
        )

        # Apply GV post-filtering
        if post_filter_type in ["nnsvs", "gv"] and gv_postfilter:
            logging.debug("Apply GV post-filtering")
            linguistic_features = fe.linguistic_features(
                modified_labels,
                self.binary_dict,
                self.numeric_dict,
                add_frame_features=True,
                subphone_features=self.config.acoustic.subphone_features,
                frame_shift=hts_frame_shift,
            )
            # TODO: remove hardcode
            in_rest_idx = 0
            note_frame_indices = linguistic_features[:, in_rest_idx] <= 0

            mgc_end_dim = static_stream_sizes[0]
            acoustic_features[:, :mgc_end_dim] = variance_scaling(
                self.acoustic_out_scaler.var_.reshape(-1)[:mgc_end_dim],
                acoustic_features[:, :mgc_end_dim],
                offset=2,
                note_frame_indices=note_frame_indices,
            )

        # Learned post-filter using nnsvs
        if post_filter_type == "nnsvs" and self.postfilter_model is not None:
            logging.debug("Apply mgc_postfilter")
            # (1) Raw spectrogram or (2) mgc
            rawsp_output = self.postfilter_config.stream_sizes[0] >= 128
            # If the post-filter output is a raw spectrogram, convert mgc to log spectrogram
            if rawsp_output:
                outs = split_streams(acoustic_features, static_stream_sizes)
                assert len(outs) == 4
                mgc, lf0, vuv, bap = outs
                fft_size = pyworld.get_cheaptrick_fft_size(self.config.sample_rate)
                sp = pyworld.decode_spectral_envelope(
                    mgc.astype(np.float64), self.config.sample_rate, fft_size
                ).astype(np.float32)
                sp = np.log(sp)
                acoustic_features = np.concatenate([sp, lf0, vuv, bap], axis=-1)

            in_feats = torch.from_numpy(acoustic_features).float().unsqueeze(0)
            in_feats = (
                self.postfilter_out_scaler.transform(in_feats).float().to(self.device)
            )
            # Run inference
            out_feats = self.postfilter_model.inference(in_feats, [in_feats.shape[1]])
            acoustic_features = (
                self.postfilter_out_scaler.inverse_transform(out_feats.cpu())
                .squeeze(0)
                .numpy()
            )

            # Convert log spectrogram to mgc
            # NOTE: mgc is used to reduce possible artifacts
            # Ref: https://bit.ly/3AHjstU
            if rawsp_output:
                sp, lf0, vuv, bap = split_streams(
                    acoustic_features, self.postfilter_config.stream_sizes
                )
                sp = np.exp(sp)
                mgc = pyworld.code_spectral_envelope(
                    sp.astype(np.float64), self.config.sample_rate, 60
                ).astype(np.float32)
                acoustic_features = np.concatenate([mgc, lf0, vuv, bap], axis=-1)

        # Generate WORLD parameters
        mgc, lf0, vuv, bap = gen_spsvs_static_features(
            modified_labels,
            acoustic_features,
            self.binary_dict,
            self.numeric_dict,
            self.acoustic_config.stream_sizes,
            self.acoustic_config.has_dynamic_features,
            self.config.acoustic.subphone_features,
            self.pitch_idx,
            self.acoustic_config.num_windows,
            self.config.frame_period,
            self.config.acoustic.relative_f0,
            vibrato_scale=vibrato_scale,
            vuv_threshold=vuv_threshold,
            force_fix_vuv=force_fix_vuv,
        )

        # NOTE: spectral enhancement based on Merlin's post-filter implementation
        if post_filter_type == "merlin":
            logging.debug("Apply merlin postfilter")
            alpha = pysptk.util.mcepalpha(self.config.sample_rate)
            mgc = merlin_post_filter(mgc, alpha)

        # Remove high-frequency components of mgc/bap
        # NOTE: It seems to be effective to suppress artifacts of GAN-based post-filtering
        if trajectory_smoothing:
            logging.debug("Apply trajectory smoothing")
            modfs = int(1 / (self.config.frame_period * 0.001))
            lf0[:, 0] = lowpass_filter(
                lf0[:, 0], modfs, cutoff=trajectory_smoothing_cutoff_f0
            )
            for d in range(mgc.shape[1]):
                mgc[:, d] = lowpass_filter(
                    mgc[:, d], modfs, cutoff=trajectory_smoothing_cutoff
                )
            for d in range(bap.shape[1]):
                bap[:, d] = lowpass_filter(
                    bap[:, d], modfs, cutoff=trajectory_smoothing_cutoff
                )
        return mgc, lf0, vuv, bap

    @torch.no_grad()
    def synthesize(
        self,
        mgc,
        lf0,
        vuv,
        bap,
        vocoder_type="world",
        vuv_threshold=0.1,
        apply_post_process=True,
    ):
        """Get waveform from given acoustic features.

        Args:
            mgc: mel-generalized cepstrum
            lf0: log F0
            vuv: voiced/unvoiced flags
            bap: band aperiodicity
            vocoder_type: world, pwg, or usfgan
            vuv_threshold: V/UV threshold
            apply_post_process: apply post-processing to the waveform

        Returns:
            tuple: (waveform as numpy.ndarray, sample rate)
        """
        use_mcep_aperiodicity = bap.shape[-1] > 5
        logging.debug(f"use_mcep_aperiodicity: {use_mcep_aperiodicity}")
        if not use_mcep_aperiodicity:
            bap = np.clip(bap, a_min=-60, a_max=0)

        use_world_codec = self.config.get("use_world_codec", False)
        logging.debug(f"use_world_codec: {use_world_codec}")

        vocoder_type = vocoder_type.lower()
        if vocoder_type not in ["world", "pwg", "usfgan"]:
            raise ValueError(f"Unknown vocoder type: {vocoder_type}")

        # Waveform generation by (1) WORLD or (2) neural vocoder
        if vocoder_type == "world":
            f0, spectrogram, aperiodicity = gen_world_params(
                mgc,
                lf0,
                vuv,
                bap,
                self.config.sample_rate,
                vuv_threshold=vuv_threshold,
                use_world_codec=use_world_codec,
            )
            wav = pyworld.synthesize(
                f0,
                spectrogram,
                aperiodicity,
                self.config.sample_rate,
                self.config.frame_period,
            )
        elif vocoder_type == "pwg":
            # NOTE: So far vocoder models are trained on binary V/UV features
            vuv = (vuv > vuv_threshold).astype(np.float32)
            voc_inp = (
                torch.from_numpy(
                    self.vocoder_in_scaler.transform(
                        np.concatenate([mgc, lf0, vuv, bap], axis=-1)
                    )
                )
                .float()
                .to(self.device)
            )
            wav = self.vocoder.inference(voc_inp).view(-1).to("cpu").numpy()
        elif vocoder_type == "usfgan":
            fftlen = pyworld.get_cheaptrick_fft_size(self.config.sample_rate)
            if use_mcep_aperiodicity:
                aperiodicity_order = bap.shape[-1] - 1
                alpha = pysptk.util.mcepalpha(self.config.sample_rate)
                aperiodicity = pysptk.mc2sp(
                    np.ascontiguousarray(bap).astype(np.float64),
                    fftlen=fftlen,
                    alpha=alpha,
                )
            else:
                aperiodicity = pyworld.decode_aperiodicity(
                    np.ascontiguousarray(bap).astype(np.float64),
                    self.config.sample_rate,
                    fftlen,
                )
            # fill aperiodicity with ones for unvoiced regions
            aperiodicity[vuv.reshape(-1) < vuv_threshold, 0] = 1.0
            # WORLD fails catastrophically for out of range aperiodicity
            aperiodicity = np.clip(aperiodicity, 0.0, 1.0)

            if use_mcep_aperiodicity:
                bap = pysptk.sp2mc(
                    aperiodicity,
                    order=aperiodicity_order,
                    alpha=alpha,
                )
            else:
                bap = pyworld.code_aperiodicity(
                    aperiodicity, self.config.sample_rate
                ).astype(np.float32)

            aux_feats = [mgc, bap]
            aux_feats = (
                torch.from_numpy(
                    self.vocoder_in_scaler.transform(np.concatenate(aux_feats, axis=-1))
                )
                .float()
                .to(self.device)
            )
            contf0 = np.exp(lf0)
            if self.vocoder_config.data.sine_f0_type in ["contf0", "cf0"]:
                f0_inp = contf0
            elif self.vocoder_config.data.sine_f0_type == "f0":
                f0_inp = contf0
                f0_inp[vuv < vuv_threshold] = 0
            wav = self.vocoder.inference(f0_inp, aux_feats).view(-1).to("cpu").numpy()

        if apply_post_process:
            wav = self.post_process(wav)

        return wav, self.config.sample_rate
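
    # A minimal end-to-end sketch (model directory and file names are
    # hypothetical):
    #
    #   engine = SPSVSMod(expanduser("~/models/my_voice"), device="cpu")
    #   labels = hts.load("song.lab").round_()
    #   mgc, lf0, vuv, bap = engine.get_static_features(labels)
    #   wav, sr = engine.synthesize(mgc, lf0, vuv, bap, vocoder_type="world")
    #   wavfile.write("song.wav", rate=sr, data=wav)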

    def _compute_nosil_duration(self, lab, threshold=5.0):
        """Total duration in seconds, excluding silences longer than threshold."""
        is_full_context = "@" in lab[0][-1]
        sum_d = 0
        for s, e, label in lab:
            # HTS label times are in 100 ns units
            d = (e - s) * 1e-7
            if is_full_context:
                is_silence = "-sil" in label or "-pau" in label
            else:
                is_silence = label == "sil" or label == "pau"
            if not (is_silence and d > threshold):
                sum_d += d
        return sum_d
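
    # For example, a label entry (0, 15000000, "sil") spans 1.5 s; with the
    # default threshold of 5.0 s it still counts toward the total, since only
    # silences longer than the threshold are excluded.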

    def _is_silence(self, label):
        is_full_context = "@" in label
        if is_full_context:
            is_silence = "-sil" in label or "-pau" in label
        else:
            is_silence = label == "sil" or label == "pau"
        return is_silence
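
    # Illustrative: in a full-context label (which contains "@"), silence shows
    # up as a "-sil" or "-pau" substring in the phoneme slot, whereas in a mono
    # label the whole string is just "sil" or "pau".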

    def segment_labels(
        self,
        lab,
        strict=True,
        silence_threshold=1.0,
        min_duration=10.0,
        force_split_threshold=30.0,
    ):
        """Segment labels based on sil/pau

        Example:
            [a b c sil d e f pau g h i sil j k l]
            ->
            [a b c] [d e f] [g h i] [j k l]
        """
        segments = []
        seg = hts.HTSLabelFile()
        start_indices = []
        end_indices = []
        si = 0
        large_silence_detected = False
        for idx, (s, e, label) in enumerate(lab):
            # print(label)
            d = (e - s) * 1e-7
            is_silence = self._is_silence(label)
            if len(seg) > 0:
                # Compute duration except for long silences
                seg_d = self._compute_nosil_duration(seg)
            else:
                seg_d = 0
            # Let's try to split.
            # If we find a large silence, force a split regardless of min_duration.
            if (d > force_split_threshold) or (
                is_silence and d > silence_threshold and seg_d > min_duration
            ):
                if idx == len(lab) - 1:
                    continue
                elif len(seg) > 0:
                    if d > force_split_threshold:
                        large_silence_detected = True
                    else:
                        large_silence_detected = False
                    seg.append((s, e, label), strict)
                    start_indices.append(si)
                    end_indices.append(idx)
                    segments.append(seg)
                    seg = hts.HTSLabelFile()
                    si = idx
                seg.append((s, e, label), strict)
                continue
            else:
                if len(seg) == 0:
                    si = idx
                seg.append((s, e, label), strict)

        if len(seg) > 0:
            seg_d = self._compute_nosil_duration(seg)
            # If the last segment is short, combine it with the previous segment.
            if seg_d < min_duration and not large_silence_detected and end_indices:
                end_indices[-1] = si + len(seg) - 1
            else:
                start_indices.append(si)
                end_indices.append(si + len(seg) - 1)

        segments2 = []
        for s, e in zip(start_indices, end_indices):
            seg = lab[s : e + 1]
            offset = seg.start_times[0]
            seg.start_times = np.asarray(seg.start_times) - offset
            seg.end_times = np.asarray(seg.end_times) - offset
            segments2.append(seg)
        return segments2

    def fix_br_vuv(self, lab, vuv):
        if len(lab.start_times) != len(set(lab.start_times)):
            raise RuntimeError("Label file is not ground-truth-ed")
        print("Fix V/UV")
        for idx, (s, e, label) in tqdm(enumerate(lab)):
            if "-br" in label:
                # Convert HTS times (100 ns units) to frame indices
                si = np.round(s / (self.config.frame_period * 1e4)).astype(np.uint32)
                ei = np.round(e / (self.config.frame_period * 1e4)).astype(np.uint32)
                if ei > (vuv.shape[0] - 1):
                    ei = vuv.shape[0] - 1
                vuv[si:ei] = 0
        return vuv
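
    # For example, with frame_period = 5 (ms) one frame is 5e4 HTS units of
    # 100 ns each, so a breath label starting at 3000000 maps to frame
    # 3000000 / 50000 = 60.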

    def suppress_silent(self, lab, wav, sr, suppression_ratio):
        if len(lab.start_times) != len(set(lab.start_times)):
            raise RuntimeError("Label file is not ground-truth-ed")
        print("Fix silent")
        for idx, (s, e, label) in enumerate(lab):
            if self._is_silence(label):
                # Convert HTS times (100 ns units) to sample indices
                si = np.round(s * 1e-7 * sr).astype(np.uint32)
                ei = np.round(e * 1e-7 * sr).astype(np.uint32)
                if ei > (wav.shape[0] - 1):
                    ei = wav.shape[0] - 1
                wav[si:ei] = np.round(wav[si:ei] * (1.0 - suppression_ratio)).astype(
                    np.int16
                )
        return wav
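
    # With the default suppression_ratio of 0.8, samples in silent regions are
    # scaled to 20% of their original amplitude rather than zeroed outright.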

    def post_process(self, wav):
        logging.debug("Apply Post Process")
        wav = bandpass_filter(wav, self.config.sample_rate)
        if np.max(wav) > 10:
            # data is likely already in [-32768, 32767]
            if np.abs(wav).max() > 32767:
                # rescale into the int16 range before casting
                wav = wav / np.abs(wav).max() * 32767
            wav = wav.astype(np.int16)
        else:
            # data is likely in [-1.0, 1.0]
            if np.abs(wav).max() > 1.0:
                wav = wav / np.abs(wav).max()
            wav = (wav * 32767.0).astype(np.int16)
        return wav
        # return (wav / np.max(np.abs(wav)) * (2 ** 15 - 1)).astype(np.int16)

    def embed_spk(self, full_labels, spk):
        new_contexts = []
        for _, _, context in full_labels:
            for id, pre, post in [("p16", "\\]", "\\/")]:
                match = re.search(f"{pre}([0-9x]+){post}", context)
                if match is not None:
                    assert len(match.groups()) == 1
                    num = match.group(0)[1:-1]
                    if len(num) > 0:
                        pre = pre.replace("\\", "")
                        post = post.replace("\\", "")
                        context = context.replace(
                            match.group(0), f"{pre}{spk}{post}", 1
                        )
                else:
                    raise RuntimeError(f"Cannot find {id} label")
            new_contexts.append(context)

        new_full_labels = hts.HTSLabelFile()
        new_full_labels.start_times = full_labels.start_times
        new_full_labels.end_times = full_labels.end_times
        new_full_labels.contexts = new_contexts
        return new_full_labels
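
# Illustrative: embed_spk rewrites the p16 field (used here as the speaker id)
# of each full-context label; e.g. a context containing "]0/" (values
# hypothetical) becomes "]3/" for spk=3.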


def convert_bpm(labels, source_bpm, target_bpm, autodetect_source_bpm=True):
    if autodetect_source_bpm:
        bpm = []
        e5re1 = re.compile(r'^.+~(\d+)!.+$')
        for i in range(len(labels)):
            matched = re.match(e5re1, labels.contexts[i])
            bpm.append(int(matched[1]))
        _source_bpm = statistics.mode(bpm)
        logging.debug(f"auto-detected source bpm: {_source_bpm}")
    elif source_bpm != 0:
        _source_bpm = source_bpm
    else:
        raise RuntimeError("Source BPM is not specified and autodetect_source_bpm is disabled")
    if target_bpm != 0:
        _target_bpm = int(target_bpm)
    else:
        raise RuntimeError("Target BPM is not specified")
    convert_ratio = float(_target_bpm / _source_bpm)
    logging.debug(f"convert_ratio: {convert_ratio}")
    regexp = re.compile(
        r'(?P<p1_d4>^.+%)(?P<d5>\w+)(?P<d6_e4>\|.+~)(?P<e5>\w+)(?P<e6>!.+@)'
        r'(?P<e7>\w+)(?P<e8_e11>#.*\|)(?P<e12>\w+)(?P<e12_e13>\[)(?P<e13>\w+)'
        r'(?P<e14_e19>&.*_)(?P<e20>\w+)(?P<e20_e21>;)(?P<e21>\w+)'
        r'(?P<e22_f4>\$.*\$)(?P<f5>\w+)(?P<f6_j3>\$.*$)'
    )
    for i in range(len(labels)):
        # Quantize converted times to 50000-unit (5 ms) boundaries
        labels.start_times[i] = round(labels.start_times[i] / convert_ratio / 50000) * 50000
        labels.end_times[i] = round(labels.end_times[i] / convert_ratio / 50000) * 50000
        m = re.match(regexp, labels.contexts[i])
        # logging.debug(f"p1_d4: {m.group('p1_d4')}")
        new_d5 = str(round(int(m.group('d5')) * convert_ratio)) if m.group('d5') != 'xx' else m.group('d5')
        new_e5 = str(round(int(m.group('e5')) * convert_ratio)) if m.group('e5') != 'xx' else m.group('e5')
        new_e7 = str(round(int(m.group('e7')) / convert_ratio)) if m.group('e7') != 'xx' else m.group('e7')
        new_e12 = str(round(int(m.group('e12')) / convert_ratio)) if m.group('e12') != 'xx' else m.group('e12')
        new_e13 = str(round(int(m.group('e13')) / convert_ratio)) if m.group('e13') != 'xx' else m.group('e13')
        new_e20 = str(round(int(m.group('e20')) / convert_ratio)) if m.group('e20') != 'xx' else m.group('e20')
        new_e21 = str(round(int(m.group('e21')) / convert_ratio)) if m.group('e21') != 'xx' else m.group('e21')
        new_f5 = str(round(int(m.group('f5')) * convert_ratio)) if m.group('f5') != 'xx' else m.group('f5')
        new_context = (
            m.group('p1_d4') + new_d5 + m.group('d6_e4') + new_e5 + m.group('e6')
            + new_e7 + m.group('e8_e11') + new_e12 + m.group('e12_e13') + new_e13
            + m.group('e14_e19') + new_e20 + m.group('e20_e21') + new_e21
            + m.group('e22_f4') + new_f5 + m.group('f6_j3')
        )
        assert abs(len(labels.contexts[i]) - len(new_context)) < 10
        # logging.debug(f"new_context: {new_context}")
        labels.contexts[i] = new_context
    return labels, _source_bpm, _target_bpm
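
# Worked example: converting from 120 to 150 BPM gives convert_ratio = 1.25.
# Start/end times shrink by 1/1.25; the fields multiplied by convert_ratio
# (d5, e5, f5, i.e. the tempo-like fields) increase by 1.25x, while those
# divided by it (e7, e12, e13, e20, e21) decrease correspondingly.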


def get_shifted_midi_name(midi_name, pitch_shift):
    if pitch_shift == 0:
        return midi_name
    if midi_name == 'xx':
        return midi_name
    assert midi_name in NOTE_MAPPING, f"Failed to lookup {midi_name} from NOTE_MAPPING"
    new_midi_name = [
        k for k, v in NOTE_MAPPING.items() if v == NOTE_MAPPING[midi_name] + pitch_shift
    ]
    assert len(new_midi_name) == 1, f"Failed to shift {midi_name} by {pitch_shift}"
    return new_midi_name[0]
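
# For example, with NOTE_MAPPING mapping note names to MIDI note numbers,
# shifting "C4" by +2 semitones yields "D4" (assuming both names are present
# in NOTE_MAPPING).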


def pitch_shift(labels, pitch_shift):
    regexp = re.compile(
        r'(?P<p1_d1>^.+/D:)(?P<d1>\w+)(?P<d1_e1>!.+/E:)(?P<e1>\w+)'
        r'(?P<e1_f1>\].+/F:)(?P<f1>\w+)(?P<f1_j3>#.+$)'
    )
    for i in range(len(labels)):
        m = re.match(regexp, labels.contexts[i])
        new_d1 = get_shifted_midi_name(m.group('d1'), pitch_shift)
        new_e1 = get_shifted_midi_name(m.group('e1'), pitch_shift)
        new_f1 = get_shifted_midi_name(m.group('f1'), pitch_shift)
        new_context = (
            m.group('p1_d1') + new_d1 + m.group('d1_e1') + new_e1
            + m.group('e1_f1') + new_f1 + m.group('f1_j3')
        )
        # logging.debug(f"new_context: {new_context}")
        labels.contexts[i] = new_context
    return labels


def force_restricted_ph_duration(labels, ph_stats, restricted_ph_level):
    s = labels.frame_shift
    bad_phs = ["n", "m", "y", "s", "sh"]
    logging.debug(f"restricted_ph_level: {restricted_ph_level}")
    regexp = re.compile(r'(?P<p1_p3>^.+-)(?P<p4>\w+)(?P<p5_j3>\+.*$)')
    for i in range(len(labels) - 1):
        # logging.debug(f"contexts: {labels.contexts[i]}")
        m = re.match(regexp, labels.contexts[i])
        ph = m.group('p4')
        logging.debug(f"p4: {ph}")
        if ph == "sil":
            logging.debug("skipping")
            continue
        # Upper limit of the phoneme duration in 100 ns units
        if restricted_ph_level == "1sd":
            dur_ul = np.int64((float(ph_stats[ph]["mean"]) + float(ph_stats[ph]["std"])) * 1e7)
        elif restricted_ph_level == "2sd":
            dur_ul = np.int64((float(ph_stats[ph]["mean"]) + float(ph_stats[ph]["std"]) * 2) * 1e7)
        elif restricted_ph_level == "max":
            dur_ul = np.int64(float(ph_stats[ph]["max"]) * 1e7)
        else:
            raise RuntimeError(f"Wrong restricted_ph_level: {restricted_ph_level}")
        logging.debug(f"dur_ul: {dur_ul}")
        if labels.end_times[i] - labels.start_times[i] > dur_ul:
            logging.debug(
                f"The duration of phoneme {ph} at {labels.start_times[i]} exceeds dur_ul: {dur_ul}."
            )
            if ph not in bad_phs:
                logging.debug("skipping")
                continue
            logging.debug(f"Original start and end times: {labels.start_times[i]} {labels.end_times[i]}")
            # Clip the duration to the upper limit, rounded to the frame shift,
            # and shift the next phoneme's start time accordingly
            labels.end_times[i] = labels.start_times[i] + np.round(dur_ul / s).astype(np.int64) * s
            labels.start_times[i + 1] = labels.end_times[i]
            logging.debug(f"Modified start and end times: {labels.start_times[i]} {labels.end_times[i]}")
            logging.debug(f"Next phoneme starts at {labels.start_times[i + 1]}")
    return labels
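
# The phoneme-statistics YAML is assumed to look roughly like this (keys
# inferred from the lookups above; yaml.BaseLoader yields strings, which are
# converted with float()):
#
#   n:
#     mean: "0.08"
#     std: "0.03"
#     max: "0.35"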


def get_parser():
    parser = argparse.ArgumentParser(
        description='Sample script for SPSVS class'
    )
    parser.add_argument('model_dir', metavar='MODEL_DIR', type=str, help='packed model dir')
    parser.add_argument('in_dir', metavar='IN_DIR', type=str, help='input directory')
    parser.add_argument('out_dir', metavar='OUT_DIR', type=str, help='output directory')
    parser.add_argument('--utt_list', metavar='FILE', type=str, default='song_list.txt', help='list of utterances (default: song_list.txt)')
    parser.add_argument('--vocoder_type', metavar='TYPE', type=str, default='world', help='vocoder type (default: world)')
    parser.add_argument('--post_filter_type', metavar='TYPE', type=str, default='merlin', help='post-filter type (default: merlin)')
    parser.add_argument('--disable_gv_postfilter', action='store_true', help='disable GV post-filter')
    parser.add_argument('--disable_trajectory_smoothing', action='store_true', help='disable trajectory smoothing')
    parser.add_argument('--trajectory_smoothing_cutoff', type=int, default=50, help='trajectory smoothing cutoff')
    parser.add_argument('--trajectory_smoothing_cutoff_f0', type=int, default=20, help='trajectory smoothing cutoff for F0')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--device', metavar='DEVICE', type=str, default='cuda', help='device (default: cuda)')
    parser.add_argument('--vuv_threshold', type=float, default=0.3, help='V/UV threshold (default: 0.3)')
    parser.add_argument('--silence_threshold', type=float, default=1.0, help='silence threshold for split (default: 1.0)')
    parser.add_argument('--min_duration', type=float, default=30.0, help='minimum duration for split (default: 30.0)')
    parser.add_argument('--force_split_threshold', type=float, default=60.0, help='force split threshold (default: 60.0)')
    parser.add_argument('--vibrato_scale', type=float, default=1.0, help='vibrato scale (default: 1.0)')
    parser.add_argument('--split', action='store_true', help='split synthesis')
    parser.add_argument('--spk', type=int, default=0, help='SPK identifier (default: 0)')
    parser.add_argument('--return_states', action='store_true', help='return states')
    parser.add_argument('--force_fix_vuv', action='store_true', help='enable force-fix V/UV feature')
    parser.add_argument('--fix_br_vuv', action='store_true', help='fix V/UV of breath (br) segments')
    parser.add_argument('--force_whispering', action='store_true', help='force whispering')
    parser.add_argument('--suppress_silent', action='store_true', help='suppress silent regions')
    parser.add_argument('--suppression_ratio', type=float, default=0.8, help='suppression ratio (default: 0.8)')
    parser.add_argument('--disable_post_process', action='store_true', help='disable post-processing')
    parser.add_argument('--convert_bpm', action='store_true', help='convert BPM')
    parser.add_argument('--autodetect_source_bpm', action='store_true', help='detect source BPM automatically')
    parser.add_argument('--source_bpm', type=int, default=0, help='source BPM')
    parser.add_argument('--target_bpm', type=int, default=0, help='target BPM')
    parser.add_argument('--pitch_shift', type=int, default=0, help='pitch shift in semitones')
    parser.add_argument('--ph_stats_yaml', type=str, default='', help='phoneme statistics YAML file')
    parser.add_argument('--restricted_ph_level', type=str, default='none', help='restriction level of phoneme duration (none, 1sd, 2sd, max)')
    return parser
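
# Example invocation (script name and paths are hypothetical):
#
#   python spsvs_mod.py ~/models/my_voice ./songs ./out \
#       --vocoder_type world --split --vuv_threshold 0.3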


if __name__ == "__main__":
    args = get_parser().parse_args(sys.argv[1:])
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    logging.debug(args)

    trajectory_smoothing = not args.disable_trajectory_smoothing
    gv_postfilter = not args.disable_gv_postfilter
    apply_post_process = not args.disable_post_process

    spsvs = SPSVSMod(args.model_dir, args.device)
    logging.debug(spsvs)

    in_dir = expanduser(args.in_dir)
    out_dir = expanduser(args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    with open(expanduser(join(in_dir, args.utt_list))) as f:
        lines = list(filter(lambda s: len(s.strip()) > 0, f.readlines()))
        print(f"Processing {len(lines)} utterances...")
        for idx in tqdm(range(len(lines))):
            utt_id = lines[idx].strip()
            label_path = join(in_dir, f"{utt_id}.lab")
            if not exists(label_path):
                raise RuntimeError(f"Label file does not exist: {label_path}")

            # Reset the filename postfix for each utterance
            postfix = ""

            # load labels
            labels = hts.load(label_path).round_()
            modified_labels = spsvs.get_modified_labels(labels)
            # print(modified_labels)

            # Embed spk
            modified_labels = spsvs.embed_spk(modified_labels, args.spk)

            if args.convert_bpm:
                modified_labels, source_bpm, target_bpm = convert_bpm(
                    modified_labels, args.source_bpm, args.target_bpm, args.autodetect_source_bpm
                )
                postfix += f"_convert_from_{source_bpm}_to_{target_bpm}"

            if args.pitch_shift != 0:
                modified_labels = pitch_shift(modified_labels, args.pitch_shift)
                sign = "p" if args.pitch_shift > 0 else "m"
                postfix += f"_pitch_shift_{sign}{abs(args.pitch_shift)}"

            if args.restricted_ph_level != "none":
                ph_stats = {}
                with open(args.ph_stats_yaml, 'r', encoding='utf-8') as yml:
                    ph_stats = yaml.load(yml, Loader=yaml.BaseLoader)
                if len(ph_stats) == 0:
                    raise RuntimeError(f"Wrong phoneme statistics yaml file {args.ph_stats_yaml}.")
                modified_labels = force_restricted_ph_duration(modified_labels, ph_stats, args.restricted_ph_level)
                postfix += f"_force_restricted_ph_{args.restricted_ph_level}"

            if args.split:
                segments = spsvs.segment_labels(
                    modified_labels,
                    silence_threshold=args.silence_threshold,
                    min_duration=args.min_duration,
                    force_split_threshold=args.force_split_threshold,
                )
                print(f"Split into {len(segments)} segments...")
                for idx, seg in tqdm(enumerate(segments)):
                    with open(join(out_dir, f"{utt_id}{postfix}_spk{args.spk}_seg{idx}.lab"), "w") as of:
                        of.write(str(seg))
                    mgc, lf0, vuv, bap = spsvs.get_static_features(
                        seg,
                        args.post_filter_type,
                        gv_postfilter,
                        trajectory_smoothing,
                        args.trajectory_smoothing_cutoff,
                        args.trajectory_smoothing_cutoff_f0,
                        vuv_threshold=args.vuv_threshold,
                        vibrato_scale=args.vibrato_scale,
                        force_fix_vuv=args.force_fix_vuv,
                        ground_truth=True,
                    )
                    if args.fix_br_vuv:
                        vuv = spsvs.fix_br_vuv(seg, vuv)
                    if args.force_whispering:
                        vuv[:] = 0
                    wav, sr = spsvs.synthesize(
                        mgc,
                        lf0,
                        vuv,
                        bap,
                        args.vocoder_type,
                        args.vuv_threshold,
                        apply_post_process,
                    )
                    if args.suppress_silent:
                        wav = spsvs.suppress_silent(seg, wav, sr, args.suppression_ratio)
                    out_wav_path = join(out_dir, f"{utt_id}{postfix}_spk{args.spk}_seg{idx}.wav")
                    wavfile.write(out_wav_path, rate=sr, data=wav.astype(np.int16))
            else:
                with open(join(out_dir, f"{utt_id}{postfix}_spk{args.spk}.lab"), "w") as of:
                    of.write(str(modified_labels))
                mgc, lf0, vuv, bap = spsvs.get_static_features(
                    modified_labels,
                    args.post_filter_type,
                    gv_postfilter,
                    trajectory_smoothing,
                    args.trajectory_smoothing_cutoff,
                    args.trajectory_smoothing_cutoff_f0,
                    vuv_threshold=args.vuv_threshold,
                    vibrato_scale=args.vibrato_scale,
                    force_fix_vuv=args.force_fix_vuv,
                    ground_truth=True,
                )
                if args.fix_br_vuv:
                    vuv = spsvs.fix_br_vuv(modified_labels, vuv)
                if args.force_whispering:
                    vuv[:] = 0
                wav, sr = spsvs.synthesize(
                    mgc,
                    lf0,
                    vuv,
                    bap,
                    args.vocoder_type,
                    args.vuv_threshold,
                    apply_post_process,
                )
                if args.suppress_silent:
                    wav = spsvs.suppress_silent(modified_labels, wav, sr, args.suppression_ratio)
                out_wav_path = join(out_dir, f"{utt_id}{postfix}_spk{args.spk}.wav")
                wavfile.write(out_wav_path, rate=sr, data=wav.astype(np.int16))