Created
January 4, 2024 14:39
-
-
Save taroushirani/bfba9ac5dc7bfe8dd403868e213e5187 to your computer and use it in GitHub Desktop.
Tempo-shift data augmentation with preserved consonant duration
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/python | |
import argparse | |
from glob import glob | |
import logging | |
import os | |
from os.path import join, basename, splitext | |
import re | |
import sys | |
from tqdm import tqdm | |
import librosa | |
import soundfile as sf | |
import numpy as np | |
import pyrubberband as pyrb | |
from nnmnkwii.io import hts | |
from nnsvs.io.hts import get_note_indices | |
def get_parser():
    """Build the command-line parser for the tempo-shift augmentation script."""
    arg_parser = argparse.ArgumentParser(
        description="Consonant-invariant tempo-shift data augmentation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Positional arguments: where to read, where to write, how much to shift.
    arg_parser.add_argument("src_dir", type=str, help="Source directory")
    arg_parser.add_argument("dest_dir", type=str, help="Destination directory")
    arg_parser.add_argument("scale", type=float, help="Scale of tempo conversion")
    arg_parser.add_argument("--debug", action="store_true", help="Debug Mode")
    return arg_parser
def _is_full_context(labels):
    """Return True if *labels* appear to be full-context HTS labels.

    Full-context label strings contain the "@" separator character,
    whereas mono (phoneme-only) labels do not.
    """
    assert isinstance(labels, hts.HTSLabelFile)
    first_context = labels[0][-1]
    return "@" in first_context
def _is_vowel(phoneme): | |
return phoneme in ["a", "i", "u", "e", "o", "N", "I", "U"] | |
def _is_pau(phoneme): | |
return phoneme == "pau" | |
def _is_br(phoneme): | |
return phoneme == "br" | |
def _is_special(phoneme): | |
# "cl", "br" | |
return phoneme == "cl" or _is_br(phoneme) | |
def _is_consonant(phoneme): | |
return not _is_vowel(phoneme) and not _is_pau(phoneme) and not _is_special(phoneme) | |
def _has_vowel_reduction(phonemes): | |
ret = True | |
for phoneme in phonemes: | |
if _is_pau(phoneme) or _is_vowel(phoneme): | |
ret = False | |
break | |
return ret | |
def _convert_mono_labels(mono_labels, note_indices, scale, tolerance=5):
    """Time-scale mono labels note by note, preserving consonant durations.

    Each note's total duration is multiplied by *scale*. Consonants and the
    special phonemes ("cl", "br") keep their original durations; the note's
    vowel (or "pau") absorbs whatever remains of the scaled note duration
    ("residue"). This keeps consonant articulation intact while changing
    tempo.

    Args:
        mono_labels (hts.HTSLabelFile): source mono labels. Times are
            presumably in HTS 100 ns units (see the 1e-7 factor in
            _convert_wav) — TODO confirm.
        note_indices: index of the first phoneme of every note.
        scale (float): tempo conversion factor.
        tolerance (int): allowed cumulative rounding error between the
            scaled original end time and the rebuilt end time.

    Returns:
        hts.HTSLabelFile: converted labels (same number of entries).

    Raises:
        RuntimeError: on vowel reduction in a note, an impossible phoneme
            layout, an unknown phoneme, or cumulative error > *tolerance*.
    """
    logging.debug(f"scale: {scale}")
    logging.debug(f"mono_labels.contexts: {mono_labels.contexts}")
    logging.debug(f"phoneme of note_indices: {[mono_labels.contexts[i] for i in note_indices]}")
    new_mono_labels = hts.HTSLabelFile()
    for idx in range(len(note_indices)):
        note_index = note_indices[idx]
        logging.debug(f"idx: {idx}, note_index: {note_index}")
        if note_index == len(mono_labels) - 1:
            # Trailing "pau": scale it and append after the last label built
            # so far (assumes at least one earlier label exists — TODO confirm
            # a file never consists of a single trailing pau).
            logging.debug("Last pau")
            note_duration = mono_labels.end_times[note_index] - mono_labels.start_times[note_index]
            new_mono_labels.append([new_mono_labels.end_times[-1],
                                    new_mono_labels.end_times[-1] + int(note_duration * scale),
                                    mono_labels.contexts[note_index]], strict=False)
        else:
            phoneme_num_in_note = note_indices[idx+1] - note_index
            logging.debug(f"range(note_index, phoneme_num_in_note): {range(note_index, note_index+phoneme_num_in_note)}")
            phonemes_in_note = [mono_labels.contexts[i] for i in range(note_index, note_index+phoneme_num_in_note)]
            logging.debug(f"phoneme in notes: {phonemes_in_note}")
            if _has_vowel_reduction(phonemes_in_note):
                raise RuntimeError(f"Vowel reduction detected: {phonemes_in_note}")
            note_duration = mono_labels.end_times[note_index + phoneme_num_in_note - 1] - mono_labels.start_times[note_index]
            # "residue" is the scaled note duration still to be distributed;
            # fixed-length phonemes subtract from it and the vowel/pau takes
            # whatever is left.
            residue = int(note_duration * scale)
            logging.debug(f"note_duration: {note_duration}")
            for pos in range(phoneme_num_in_note):
                logging.debug(f"note_index+pos: {note_index+pos}")
                if _is_pau(mono_labels.contexts[note_index+pos]):
                    if note_index == 0:
                        logging.debug("First 'pau'")
                        start_time = 0
                    else:
                        start_time = new_mono_labels.end_times[-1]
                    if phoneme_num_in_note == 1:
                        logging.debug("Current note consists of only 'pau'")
                        new_mono_labels.append([start_time,
                                                start_time + residue,
                                                mono_labels.contexts[note_index]], strict=False)
                    elif phoneme_num_in_note == 2:
                        # "pau" followed by "br": keep the breath's original
                        # duration and give the pau the rest; the "br" itself
                        # is appended by the [:special:] branch below.
                        logging.debug("Current note consists of 'pau' 'br'")
                        br_duration = mono_labels.end_times[note_index+1] - mono_labels.start_times[note_index+1]
                        residue -= br_duration
                        new_mono_labels.append([start_time,
                                                start_time + residue,
                                                mono_labels.contexts[note_index]], strict=False)
                    else:
                        raise RuntimeError(f"The impossible phoneme_num_in_note: {phoneme_num_in_note}")
                elif _is_consonant(mono_labels.contexts[note_index+pos]):
                    logging.debug("[:consonant:]")
                    # Consonants keep their original (unscaled) duration.
                    phoneme_duration = mono_labels.end_times[note_index+pos] - mono_labels.start_times[note_index+pos]
                    new_mono_labels.append([new_mono_labels.end_times[-1],
                                            new_mono_labels.end_times[-1] + phoneme_duration,
                                            mono_labels.contexts[note_index+pos]], strict=False)
                    residue -= phoneme_duration
                elif _is_vowel(mono_labels.contexts[note_index+pos]):
                    logging.debug("[:vowel:]")
                    if pos == phoneme_num_in_note - 1:
                        logging.debug("Current note ends with [:vowel:]")
                        # Vowel is last: it absorbs all remaining duration.
                        new_mono_labels.append([new_mono_labels.end_times[-1],
                                                new_mono_labels.end_times[-1] + residue,
                                                mono_labels.contexts[note_index+pos]], strict=False)
                    else:
                        # A special phoneme ends the note (assumes exactly one
                        # phoneme follows the vowel — TODO confirm): reserve its
                        # original duration before giving the vowel the rest.
                        logging.debug("Current note ends with [:special:]")
                        special_duration = mono_labels.end_times[note_index+pos+1] - mono_labels.start_times[note_index+pos+1]
                        residue -= special_duration
                        new_mono_labels.append([new_mono_labels.end_times[-1],
                                                new_mono_labels.end_times[-1] + residue,
                                                mono_labels.contexts[note_index+pos]], strict=False)
                elif _is_special(mono_labels.contexts[note_index+pos]):
                    logging.debug("[:special:]")
                    if pos != phoneme_num_in_note - 1:
                        raise RuntimeError("[:special:] is not located as the last phoneme.")
                    # Specials keep their original duration; it was already
                    # deducted from the residue by the preceding vowel/pau.
                    phoneme_duration = mono_labels.end_times[note_index+pos] - mono_labels.start_times[note_index+pos]
                    new_mono_labels.append([new_mono_labels.end_times[-1],
                                            new_mono_labels.end_times[-1] + phoneme_duration,
                                            mono_labels.contexts[note_index+pos]], strict=False)
                else:
                    raise RuntimeError(f"Unknown phoneme: {mono_labels.contexts[note_index+pos]}")
    logging.debug(f"int(mono_labels.end_times[-1]*scale): {int(mono_labels.end_times[-1]*scale)}")
    logging.debug(f"new_mono_labels.end_times[-1]: {new_mono_labels.end_times[-1]}")
    assert len(mono_labels) == len(new_mono_labels)
    # Integer truncation accumulates per note; reject if it drifts too far.
    if abs(int(mono_labels.end_times[-1]*scale) - new_mono_labels.end_times[-1]) > tolerance:
        raise RuntimeError("Cumulative error exceeds the tolerance.")
    for i in range(len(mono_labels)):
        logging.debug(f"{mono_labels.start_times[i]} {mono_labels.end_times[i]} {mono_labels.contexts[i]} | {new_mono_labels.start_times[i]} {new_mono_labels.end_times[i]} {new_mono_labels.contexts[i]} | {mono_labels.end_times[i] - mono_labels.start_times[i]} {new_mono_labels.end_times[i] - new_mono_labels.start_times[i]}")
    return new_mono_labels
def _convert_full_labels(full_labels, scale):
    """Time-scale full-context labels and rewrite embedded tempo/length fields.

    Start/end times are multiplied by *scale*. Inside each context string,
    tempo fields (d5/e5/f5) are divided by *scale* and length fields
    (d7/e7/f7, e12/13, e20/21, e31/32, e37/38, e43/44, e51/52) are multiplied
    by *scale*. Each field is located by its single-character pre/post
    delimiters; a non-match means the field holds "xx" (undefined) and is
    left untouched.

    Args:
        full_labels (hts.HTSLabelFile): source full-context labels.
        scale (float): tempo conversion factor.

    Returns:
        hts.HTSLabelFile: converted labels with scaled times and contexts.
    """
    new_s = []
    new_e = []
    new_contexts = []
    for s, e, context in full_labels:
        new_s.append(int(s * scale))
        new_e.append(int(e * scale))
        # Tempo: d5, e5, f5 — scaled inversely (faster tempo <-> shorter notes).
        for field_id, pre, post in [("d5", "%", "\\|"), ("e5", "~", "!"), ("f5", "\\$", "\\$")]:
            match = re.search(f"{pre}([0-9]+){post}", context)
            # if not "xx"
            if match is not None:
                assert len(match.groups()) == 1
                num = match.group(1)
                if len(num) > 0:
                    pre = pre.replace("\\", "")
                    post = post.replace("\\", "")
                    new_num = int(round(float(num) / scale))
                    logging.debug(f"id: {field_id}, old_tempo: {num}, new_tempo: {new_num}")
                    # Replace only the first occurrence to avoid clobbering
                    # identical digit runs elsewhere in the context.
                    context = context.replace(
                        match.group(0), f"{pre}{new_num}{post}", 1
                    )
        # Length in sec or msec: d7, e7, f7
        # e12/13, e20/21, e31/32, e37/e38, e43/44, e51/52
        for field_id, pre, post in [
            ("d7", "&", ";"),
            ("e7", "@", "#"),
            ("f7", "\\+", "%"),
            ("e12", "\\|", "\\["),
            ("e13", "\\[", "&"),
            ("e20", "_", ";"),
            ("e21", ";", "\\$"),
            ("e31", "~", "="),
            ("e32", "=", "@"),
            ("e37", "#", "\\|"),
            ("e38", "\\|", "\\|"),
            ("e43", "\\+", "\\["),
            ("e44", "\\[", ";"),
            ("e51", "\\^", "@"),
            ("e52", "@", "\\["),
        ]:
            match = re.search(f"{pre}([0-9]+){post}", context)
            # if not "xx"
            if match is not None:
                assert len(match.groups()) == 1
                num = match.group(1)
                if len(num) > 0:
                    pre = pre.replace("\\", "")
                    post = post.replace("\\", "")
                    # NOTE: ensure > 0
                    new_num = max(int(float(num) * scale), 1)
                    logging.debug(f"id: {field_id}, old_length(by 0.01 sec): {num}, new_length(by 0.01 sec): {new_num}")
                    context = context.replace(
                        match.group(0), f"{pre}{new_num}{post}", 1
                    )
        new_contexts.append(context)
    new_full_labels = hts.HTSLabelFile()
    new_full_labels.start_times = new_s
    new_full_labels.end_times = new_e
    new_full_labels.contexts = new_contexts
    assert len(full_labels) == len(new_full_labels)
    assert int(full_labels.end_times[-1]*scale) == new_full_labels.end_times[-1]
    return new_full_labels
def _convert_wav(wav, sr, mono_labels, new_mono_labels, scale, tolerance=5):
    """Time-stretch *wav* so phoneme boundaries match *new_mono_labels*.

    Builds a frame-level time map from the old/new label end times (the 1e-7
    factor converts label times — presumably HTS 100 ns units — to seconds)
    and stretches the audio with pyrubberband's timemap_stretch.

    NOTE(review): the *tolerance* parameter is accepted but never used in
    this function — confirm whether it was meant to gate the length mismatch.
    """
    logging.debug(f"wav.shape: {wav.shape}")
    time_map = []
    for idx in range(len(mono_labels)):
        # Convert each label end time to a sample-frame index, before/after.
        end_frame = int(mono_labels.end_times[idx] * 1e-7 * sr)
        new_end_frame = int(new_mono_labels.end_times[idx] * 1e-7 * sr)
        logging.debug(f"phoneme: {mono_labels.contexts[idx]}, end_frame: {end_frame}, new_end_frame: {new_end_frame}")
        time_map.append([end_frame, new_end_frame])
    # Reconcile the last mapped frame with the actual waveform length:
    # trim the wav if the labels end early, or clamp the map if they overrun.
    if time_map[-1][0] < wav.shape[0]:
        logging.debug(f"time_map[-1][0]: {time_map[-1][0]} is smaller than wav.shape[0]: {wav.shape[0]}")
        wav = wav[0:time_map[-1][0]]
    elif time_map[-1][0] > wav.shape[0]:
        logging.debug(f"time_map[-1][0]: {time_map[-1][0]} is bigger than wav.shape[0]: {wav.shape[0]}")
        time_map[-1][0] = wav.shape[0]
        time_map[-1][1] = int(wav.shape[0] * scale)
    # HACK: overwrites a private pyrubberband attribute so it invokes the
    # "rubberband-r3" binary (Rubber Band engine R3) — fragile across
    # pyrubberband versions; verify it still takes effect after upgrades.
    pyrb.__RUBBERBAND_UTIL = 'rubberband-r3'
    new_wav = pyrb.timemap_stretch(wav, sr, time_map)
    logging.debug(f"new_wav.shape: {new_wav.shape}")
    return new_wav
# Script entry point: for every full-context label file under
# <src_dir>/labels/full, write tempo-shifted mono/full labels and a
# time-stretched wav under <dest_dir>, file names tagged with the scale.
if __name__ == "__main__":
    args = get_parser().parse_args(sys.argv[1:])
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    src_dir = args.src_dir
    dest_dir = args.dest_dir
    scale = args.scale
    # Mirror the source data layout (labels/mono, labels/full, wav).
    new_mono_lab_dir = join(dest_dir, "labels", "mono")
    new_full_lab_dir = join(dest_dir, "labels", "full")
    new_wav_dir = join(dest_dir, "wav")
    os.makedirs(new_mono_lab_dir, exist_ok=True)
    os.makedirs(new_full_lab_dir, exist_ok=True)
    os.makedirs(new_wav_dir, exist_ok=True)
    # e.g. scale 0.9 -> "citsda_0_9"; appended to every output file name.
    postfix = f"citsda_{str(scale).replace('.', '_')}"
    logging.debug(f"postfix: {postfix}")
    full_lab_files = sorted(glob(join(src_dir, "labels", "full", "*.lab")))
    song_list = []
    for full_lab_file in tqdm(full_lab_files):
        logging.debug(f"full_lab_file: {full_lab_file}")
        song_name = splitext(basename(full_lab_file))[0]
        song_list.append(song_name)
        full_labels = hts.load(full_lab_file)
        assert _is_full_context(full_labels)
        new_full_labels = _convert_full_labels(full_labels, scale)
        # Note boundaries come from the full labels but are applied to the
        # mono labels — assumes both share the phoneme sequence (TODO confirm).
        note_indices = get_note_indices(full_labels)
        mono_lab_file = join(src_dir, "labels", "mono", f"{song_name}.lab")
        logging.debug(f"mono_lab_file: {mono_lab_file}")
        mono_labels = hts.load(mono_lab_file)
        assert not _is_full_context(mono_labels)
        try:
            new_mono_labels = _convert_mono_labels(mono_labels, note_indices, scale)
        except RuntimeError as e:
            # Unconvertible songs (e.g. vowel reduction) are reported and skipped.
            print(f"ERROR: {song_name}: {e}")
            continue
        wav_file = join(src_dir, "wav", f"{song_name}.wav")
        # sr=None keeps the file's native sampling rate.
        wav, sr = librosa.load(wav_file, sr=None)
        logging.debug(f"wav.shape: {wav.shape}")
        new_wav = _convert_wav(wav, sr, mono_labels, new_mono_labels, scale)
        new_mono_lab_file = join(new_mono_lab_dir, f"{song_name}_{postfix}.lab")
        with open(new_mono_lab_file, "w") as of:
        	of.write(str(new_mono_labels))
        new_full_lab_file = join(new_full_lab_dir, f"{song_name}_{postfix}.lab")
        with open(new_full_lab_file, "w") as of:
            of.write(str(new_full_labels))
        new_wav_file = join(new_wav_dir, f"{song_name}_{postfix}.wav")
        logging.debug(f"new_wav.shape: {new_wav.shape}")
        sf.write(new_wav_file, new_wav, sr, format="WAV")
    print(song_list)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment