Created July 6, 2023 13:55
Save SolomidHero/8c4eec852a323cb8f43736a87d38bbd8 to your computer and use it in GitHub Desktop.
Visualize mel difference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the same recording at two sample rates.
# 22000 Hz (not the more common 22050) is deliberate: it makes the 22 kHz STFT
# settings an exact 1.375x scale of the 16 kHz ones, so window and hop cover
# the same durations:
#   640 / 16000 == 880 / 22000 (40 ms window)
#   160 / 16000 == 220 / 22000 (10 ms hop)
# NOTE(review): REPO_PATH and librosa are expected to be defined/imported
# earlier (e.g. in a previous notebook cell); they are not visible here.
wav_path = str(REPO_PATH / "tests/accent_test_sample/split/accent17/accent17_a17_14s.wav")
wav_16, sr16 = librosa.load(wav_path, sr=16000)
wav_22, sr22 = librosa.load(wav_path, sr=22000)
print(wav_22.shape, wav_16.shape)
# Compute a spectrogram for each resampled waveform with the same
# SpectrogramCreator. Window/hop lengths scale with the sample rate
# (640/160 @ 16 kHz, 880/220 @ 22 kHz), so both use a 40 ms window and a
# 10 ms hop.
import torch

from new_data.preprocessing.batch_processors.spectrogram_creator import (
    SpectrogramCreator,
)
from new_data.train_data_containers import Batch


def _mel_spectrogram(wav, sample_rate, win_length, hop_length):
    """Run SpectrogramCreator (on CPU) over a single waveform.

    Args:
        wav: 1-D waveform array as returned by ``librosa.load``.
        sample_rate: Sample rate of ``wav`` in Hz.
        win_length: STFT window length in samples. ``filter_length`` is set to
            the same value, matching both original configurations.
        hop_length: STFT hop length in samples.

    Returns:
        Whatever the creator returns for the batch. The code that follows
        reads its ``'spectrogram'`` entry.
    """
    mel_fn = SpectrogramCreator(
        sample_rate=sample_rate,
        filter_length=win_length,
        win_length=win_length,
        hop_length=hop_length,
        device="cpu",
    )
    # wav[None] adds a leading axis of size 1 so a single clip forms a batch.
    return mel_fn(Batch(
        audio=wav[None],
        audio_len=torch.tensor([len(wav)]),
    ))


# Use the rates librosa actually loaded at, so the spectrogram's sample_rate
# can never drift from the audio it is given.
# test1: sr=16khz, win=640, hop=160
res16 = _mel_spectrogram(wav_16, sr16, win_length=640, hop_length=160)
# test2: sr=22khz, win=880, hop=220
res22 = _mel_spectrogram(wav_22, sr22, win_length=880, hop_length=220)
# visualize
print(res16['spectrogram'].shape)

# Both spectrograms use a 10 ms hop (160 samples @ 16 kHz, 220 @ 22 kHz), so
# frame i spans the same time in each. Compare only the first N_FRAMES frames
# (~0.78 s), clamped to what both spectrograms actually have so the two
# slices always end up the same length.
# NOTE(review): assumes the last axis of 'spectrogram' is time, i.e. a
# (batch, n_mels, frames) layout -- TODO confirm against SpectrogramCreator.
N_FRAMES = 78
n_frames = min(
    N_FRAMES,
    res16['spectrogram'].shape[-1],
    res22['spectrogram'].shape[-1],
)
mel16, mel22 = (b['spectrogram'][:, :, :n_frames] for b in (res16, res22))
print(mel16.shape, mel16.min(), mel16.max())
print(mel22.shape, mel22.min(), mel22.max())

# Linear transform first: min-max rescale mel22 onto mel16's value range
# (first to [0, 1], then to [mel16.min(), mel16.max()]), so a global
# offset/scale difference does not dominate the comparison.
# A constant mel22 has zero range; map it to mel16.min() instead of
# dividing by zero.
mel22_min = mel22.min()
span22 = float(mel22.max() - mel22_min)
if span22 > 0:
    mel22 = (mel22 - mel22_min) / span22  # to [0, 1]
else:
    mel22 = mel22 - mel22_min  # constant input -> all zeros
mel22 = mel22 * (mel16.max() - mel16.min()) + mel16.min()
# Compare the aligned spectrograms:
#   1. plot both;
#   2. evaluate a range of candidate shifts with calc_shift;
#   3. plot the per-shift errors and the error at the chosen shift.
# NOTE(review): plt (matplotlib.pyplot) is expected to be imported earlier;
# it is not visible here.
from utils.plot import plot_output, calc_shift, plot_shift_errors, plot_error

plot_output(mel16, mel22, f=None)
plt.show()

# Candidate shifts 0..59 (presumably in frames -- confirm with calc_shift).
shifts = list(range(60))
shift, diffs, errors = calc_shift(mel16, mel22, shifts)
print("calced shift = ", shift)
# NOTE(review): errors is indexed by the returned shift directly. That equals
# "the error at position shifts.index(shift)" only because shifts is
# range(60) (value == index). Revisit if shifts ever stops starting at 0
# with step 1.
print("with MAE=", errors[shift].item())
plot_shift_errors(shifts, errors, f=None)
plt.show()

plot_error(mel16, mel22, shift, f=None)
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.