Created
January 15, 2024 10:42
-
-
Save libratiger/5ad9b7baac2dc43ec63e1166d54922c9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the normalized log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to an audio file (decoded with ``load_audio``), or a NumPy
        array or Tensor containing the audio waveform sampled at 16 kHz.
        A NumPy array is wrapped with ``torch.from_numpy`` (no dtype change).

    n_mels: int
        The number of Mel-frequency filters, forwarded to ``mel_filters``.
        Only 80 is documented as supported here (confirm against
        ``mel_filters``).

    padding: int
        Number of zero samples to pad to the right of the waveform.
        Values <= 0 leave the waveform unpadded.

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before the STFT.

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames) for a 1-D waveform
        A Tensor that contains the normalized log-Mel spectrogram, on the
        same device as the (possibly moved) audio.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # Drop the last STFT frame (last axis) and take the power spectrum.
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    # Floor at 1e-10 so log10 never sees zero.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Limit the dynamic range to 8 log10-units below the peak. The peak is the
    # max over the whole tensor, not per row/frame.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    # Fixed shift and scale applied to every value.
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read it as a mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio to, if necessary

    Returns
    -------
    A 1-D NumPy array containing the mono audio waveform, in float32 dtype,
    scaled from signed 16-bit PCM into the range [-1.0, 1.0).

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message includes ffmpeg's
        stderr. Errors launching ffmpeg itself (e.g. it is not on PATH) are
        not caught here and propagate from ``subprocess.run``.
    """
    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # The command is an argument list (no shell), so ``file`` is never
    # interpreted by a shell.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        # ffmpeg's stderr is not guaranteed to be valid UTF-8 (e.g. it may echo
        # non-UTF-8 file names). Decode leniently so a UnicodeDecodeError
        # cannot mask the real ffmpeg failure.
        stderr = e.stderr.decode(errors="replace")
        raise RuntimeError(f"Failed to load audio: {stderr}") from e

    # frombuffer already yields a 1-D view over ``out``; astype makes the
    # writable float32 copy, so no extra flatten() copy is needed.
    return np.frombuffer(out, np.int16).astype(np.float32) / 32768.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment