Created
April 18, 2022 13:10
-
-
Save taroushirani/bbcab5f016a31b619465bf2e0b309c5a to your computer and use it in GitHub Desktop.
Sample script for SPSVS class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/python | |
import argparse | |
import logging | |
from nnmnkwii.io import hts | |
from nnsvs.svs import SPSVS | |
import numpy as np | |
import os | |
from os.path import expanduser, exists, join | |
import sys | |
from scipy.io import wavfile | |
from tqdm import tqdm | |
def get_parser():
    """Build the command-line parser for the SPSVS sample script.

    Positional arguments:
        model_dir: packed model directory.
        in_dir: input directory.
        out_dir: output directory.

    Optional arguments:
        --utt_list: utterance list file name (default: ``song_list.txt``).
        --vocoder_type: vocoder type (default: ``world``).
        --debug: enable debug mode.
        --device: device name (default: ``cuda``).
        --disable_post_filter: disable the post filter.
        --vuv_threshold: V/UV threshold as a float (default: 0.3).
        --vibrato_scale: vibrato scale as a float (default: 1.0).
        --return_states: request the intermediate states as well.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description='Sample script for SPSVS class'
    )

    # Each positional gets its own metavar so that the usage line reads
    # "MODEL_DIR IN_DIR OUT_DIR" instead of an ambiguous "DIR DIR DIR".
    parser.add_argument('model_dir', metavar='MODEL_DIR', type=str, help='packed model dir')
    parser.add_argument('in_dir', metavar='IN_DIR', type=str, help='input directory')
    parser.add_argument('out_dir', metavar='OUT_DIR', type=str, help='output directory')

    parser.add_argument('--utt_list', metavar='FILE', type=str, default='song_list.txt', help='list of utterance (default: song_list.txt)')
    parser.add_argument('--vocoder_type', metavar='TYPE', type=str, default='world', help='vocoder_type (default: world)')
    parser.add_argument('--debug', action='store_true', help='Debug Mode')
    parser.add_argument('--device', metavar='DEVICE', type=str, default='cuda', help='device (default: cuda)')
    parser.add_argument('--disable_post_filter', action='store_true', help='Disable post filter')

    # Defaults are real floats rather than strings that argparse has to
    # convert implicitly through ``type=float``.
    parser.add_argument('--vuv_threshold', type=float, default=0.3, help='V/UV threshold(default: 0.3)')
    parser.add_argument('--vibrato_scale', type=float, default=1.0, help='Vibrato scale(default: 1.0)')
    parser.add_argument('--return_states', action='store_true', help='Return states')
    return parser
def main(argv=None):
    """Synthesize every utterance in the utterance list with ``SPSVS``.

    For each non-empty line of the utterance list (read from
    ``<in_dir>/<utt_list>``) the label file ``<in_dir>/<utt_id>.lab`` is
    loaded, passed to ``SPSVS.svs`` and the resulting waveform is written to
    ``<out_dir>/<utt_id>.wav`` as 16-bit integers. With ``--return_states``
    the ``mgc``, ``lf0``, ``vuv`` and ``bap`` states are additionally stacked
    column-wise and saved as float32 to ``<out_dir>/<utt_id>-feats.npy``.

    Args:
        argv: list of command-line arguments; ``sys.argv[1:]`` when ``None``.

    Raises:
        RuntimeError: if the label file of a listed utterance does not exist.
    """
    args = get_parser().parse_args(sys.argv[1:] if argv is None else argv)
    post_filter = not args.disable_post_filter

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    logging.debug(args)

    spsvs = SPSVS(args.model_dir, args.device)
    logging.debug(spsvs)

    in_dir = expanduser(args.in_dir)
    out_dir = expanduser(args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    # One utterance id per line; blank lines are ignored.
    with open(expanduser(join(in_dir, args.utt_list))) as f:
        utt_ids = [line.strip() for line in f if line.strip()]

    print(f"Processes {len(utt_ids)} utterances...")
    for utt_id in tqdm(utt_ids):
        label_path = join(in_dir, f"{utt_id}.lab")
        if not exists(label_path):
            raise RuntimeError(f"Label file does not exist: {label_path}")

        # load labels and question
        labels = hts.load(label_path).round_()

        result = spsvs.svs(
            labels,
            args.vocoder_type,
            post_filter,
            args.vuv_threshold,
            args.vibrato_scale,
            args.return_states,
        )

        if args.return_states:
            # With return_states the call yields a third element holding
            # the intermediate acoustic features.
            wav, sr, states = result
            mgc, lf0, vuv, bap = states["mgc"], states["lf0"], states["vuv"], states["bap"]
            static_feats = np.hstack((mgc, lf0, vuv, bap)).astype(np.float32)
            static_path = join(out_dir, f"{utt_id}-feats.npy")
            np.save(static_path, static_feats, allow_pickle=False)
            logging.debug(f"Save static features to {static_path}")
        else:
            wav, sr = result

        # The waveform is written for every utterance, whether or not the
        # states were requested.
        out_wav_path = join(out_dir, f"{utt_id}.wav")
        wavfile.write(out_wav_path, rate=sr, data=wav.astype(np.int16))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment