@zjlww
Created December 7, 2024 01:39
Stripped AudioCodecModel from NeMo @ bde672e
from typing import Tuple
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from einops import rearrange
from .modules import HiFiGANEncoder, HiFiGANDecoder, GroupFiniteScalarQuantizer
class AudioCodecModel(nn.Module):
def __init__(self):
super().__init__()
self.sample_rate = 22050
self.samples_per_frame = 1024
self.audio_encoder = HiFiGANEncoder(
down_sample_rates=(2, 2, 4, 8, 8),
encoded_dim=32,
base_channels=48,
resblock_dilation_sizes=(1,),
)
self.audio_decoder = HiFiGANDecoder(
up_sample_rates=(8, 8, 4, 2, 2),
input_dim=32,
base_channels=1024,
)
self.vector_quantizer = GroupFiniteScalarQuantizer(8, [8, 7, 6, 6])
def encode_audio(self, audio: Tensor, audio_len: Tensor) -> Tuple[Tensor, Tensor]:
"""Apply encoder on the input audio signal. Input will be padded with zeros so
the last frame has full `self.samples_per_frame` samples.
Args:
audio (Tensor): [B, T_audio].
audio_len (LongTensor): [B].
Returns:
encoded (Tensor): [B, D, T_encoded].
encoded_len (LongTensor): [B].
"""
audio, audio_len = self.pad_audio(audio, audio_len)
encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
return encoded, encoded_len
def decode_audio(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor]:
"""Apply decoder on the input. Note that the input is a non-quantized encoder output or a dequantized representation.
Args:
inputs (Tensor): [B, D, T_encoded].
input_len (LongTensor): [B]. Valid length for each example in the batch
Returns:
audio (Tensor): [B, T_audio].
Decoded output `audio` in the time domain and its length in number of samples `audio_len`.
audio_len (LongTensor): [B].
Note that `audio_len` will be a multiple of `self.samples_per_frame`.
"""
audio, audio_len = self.audio_decoder(inputs=inputs, input_len=input_len)
return audio, audio_len
def quantize(self, encoded: Tensor, encoded_len: Tensor) -> Tensor:
"""Quantize the continuous encoded representation into a discrete
representation for each frame.
Args:
encoded (Tensor): [B, D, T_encoded]. Encoded signal representation.
encoded_len (Tensor): [B]. Valid length of the encoded representation in frames.
Returns:
tokens (Tensor): [B, C, T_encoded]. A tensor of tokens for each codebook for each frame.
"""
if not self.vector_quantizer:
raise ValueError("Cannot quantize without quantizer")
# vector quantizer is returning [C, B, T], where C is the number of codebooks
tokens = self.vector_quantizer.encode(inputs=encoded, input_len=encoded_len)
# use batch first for the output
tokens = rearrange(tokens, "C B T -> B C T")
return tokens
def dequantize(self, tokens: Tensor, tokens_len: Tensor) -> Tensor:
"""Convert the discrete tokens into a continuous encoded representation.
Args:
tokens (Tensor): [B, C, T_encoded]. Discrete tokens for each codebook for each time frame.
tokens_len (Tensor): [B]. Valid length of each example in the batch.
Returns:
dequantized (Tensor): [B, D, T_encoded]. Continuous encoded representation of the discrete input representation.
"""
if not self.vector_quantizer:
raise ValueError("Cannot dequantize without quantizer")
# vector quantizer is using [C, B, T], where C is the number of codebooks
tokens = rearrange(tokens, "B C T -> C B T")
dequantized = self.vector_quantizer.decode(indices=tokens, input_len=tokens_len)
return dequantized
def encode(self, audio: Tensor, audio_len: Tensor) -> Tuple[Tensor, Tensor]:
"""Convert input time-domain audio signal into a discrete representation (tokens).
Args:
audio (Tensor): input time-domain signal, shape `(B, T_audio)`
audio_len (Tensor): valid length for each example in the batch, shape `(B,)`
Returns:
tokens (Tensor): Tokens for each codebook for each frame, shape `(B, C, T_encoded)`
encoded_len (Tensor): Corresponding valid lengths, shape `(B,)`
"""
# Apply encoder to obtain a continuous vector for each frame
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
# Apply quantizer to obtain discrete representation per frame
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
return tokens, encoded_len
def decode(self, tokens: Tensor, tokens_len: Tensor) -> Tuple[Tensor, Tensor]:
"""Convert discrete tokens into a continuous time-domain signal.
Args:
tokens (Tensor): [B, C, T_encoded]. Discrete tokens for each codebook for each time frame.
tokens_len (Tensor): [B]. Valid lengths for each example in the batch.
Returns:
audio (Tensor): [B, T_audio]. Decoded output `audio` in the time domain.
audio_len (Tensor): [B]. Length of the decoded audio in number of samples.
Note that `audio_len` will be a multiple of `self.samples_per_frame`.
"""
# Convert a discrete representation to a dequantized vector for each frame
dequantized = self.dequantize(tokens=tokens, tokens_len=tokens_len)
# Apply decoder to obtain time-domain audio for each frame
audio, audio_len = self.decode_audio(inputs=dequantized, input_len=tokens_len)
return audio, audio_len
def forward(self, audio: Tensor, audio_len: Tensor) -> Tuple[Tensor, Tensor]:
"""Apply encoder, quantizer, decoder on the input time-domain signal.
Args:
audio (Tensor): input time-domain signal, shape `(B, T_audio)`
audio_len (Tensor): valid length for each example in the batch, shape `(B,)`
Returns:
output_audio (Tensor): Reconstructed time-domain signal, shape `(B, T_audio)`
output_audio_len (Tensor): Length of the reconstructed audio in number of samples, shape `(B,)`
"""
encoded, encoded_len = self.encode_audio(audio=audio, audio_len=audio_len)
if self.vector_quantizer:
# quantize to discrete tokens
tokens = self.quantize(encoded=encoded, encoded_len=encoded_len)
# decode tokens to audio
output_audio, output_audio_len = self.decode(
tokens=tokens, tokens_len=encoded_len
)
else:
# no quantization, directly decode to audio
output_audio, output_audio_len = self.decode_audio(
inputs=encoded, input_len=encoded_len
)
return output_audio, output_audio_len
def pad_audio(self, audio: Tensor, audio_len: Tensor) -> Tuple[Tensor, Tensor]:
"""Zero pad the end of the audio so that we do not have a partial end frame.
The output will be zero-padded to have an integer number of frames of
length `self.samples_per_frame`.
Args:
audio (Tensor): input time-domain signal, shape `(B, T_audio)`
audio_len (Tensor): valid length for each example in the batch, shape `(B,)`
Returns:
padded_audio (Tensor): Padded time-domain signal, shape `(B, T_padded)`
padded_len (Tensor): Length of the padded audio, shape `(B,)`
"""
padded_len = (
self.samples_per_frame
* torch.ceil(audio_len / self.samples_per_frame).int()
)
max_len = padded_len.max().item()
num_padding = max_len - audio.shape[1]
padded_audio = F.pad(audio, (0, num_padding))
return padded_audio, padded_len
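# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the stripped NeMo code).
# A minimal encode/decode round trip with randomly initialized weights; a real
# setup would first load a trained AudioCodecModel checkpoint.
def _demo_codec_roundtrip():
    import torch

    model = AudioCodecModel().eval()
    batch, num_samples = 2, 22050                      # one second at 22.05 kHz
    audio = torch.randn(batch, num_samples).clamp(-1.0, 1.0)
    audio_len = torch.tensor([num_samples, 16000])

    with torch.no_grad():
        # Audio is zero-padded to whole 1024-sample frames, so the longest
        # example yields T_encoded = ceil(22050 / 1024) = 22 frames.
        tokens, encoded_len = model.encode(audio=audio, audio_len=audio_len)
        # tokens: [B, C=8 codebooks, T_encoded]; encoded_len: [22, 16]
        recon, recon_len = model.decode(tokens=tokens, tokens_len=encoded_len)
        # recon: [B, T_encoded * 1024]; recon_len is a multiple of 1024
    return tokens.shape, recon.shape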
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import math
import torch
from torch import Tensor
import random
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import logging
from .utils import ClampActivation, HalfSnake, Snake, mask_sequence_tensor
CONSTANT = 1e-5
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return (kernel_size * dilation - dilation) // 2
def get_padding_2d(
kernel_size: Tuple[int, int], dilation: Tuple[int, int]
) -> Tuple[int, int]:
paddings = (
get_padding(kernel_size[0], dilation[0]),
get_padding(kernel_size[1], dilation[1]),
)
return paddings
def get_down_sample_padding(kernel_size: int, stride: int) -> int:
return (kernel_size - stride + 1) // 2
def get_up_sample_padding(kernel_size: int, stride: int) -> Tuple[int, int]:
output_padding = (kernel_size - stride) % 2
padding = (kernel_size - stride + 1) // 2
return padding, output_padding
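# ---------------------------------------------------------------------------
# Worked example for the padding helpers above (added for illustration).
# A stride-1 convolution keeps the sequence length when padded by
# (kernel_size * dilation - dilation) // 2, e.g. 3 for kernel 7 and dilation 1.
# The transposed convolutions below use kernel_size = 2 * stride, which gives
# padding = (stride + 1) // 2 and output_padding = stride % 2, so the output
# length is exactly stride * input_length.
def _demo_padding_helpers():
    assert get_padding(kernel_size=7, dilation=1) == 3
    assert get_padding_2d((3, 9), (1, 1)) == (1, 4)
    assert get_down_sample_padding(kernel_size=16, stride=8) == 4
    assert get_up_sample_padding(kernel_size=16, stride=8) == (4, 0)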
class CodecActivation(nn.Module):
"""
Choose between activation based on the input parameter.
Args:
activation: Name of activation to use. Valid options are "elu" (default), "lrelu", "snake", and "half_snake".
channels: Input dimension.
"""
def __init__(self, activation: str = "elu", channels: int = 1):
super().__init__()
activation = activation.lower()
if activation == "elu":
self.activation = nn.ELU()
elif activation == "lrelu":
self.activation = torch.nn.LeakyReLU()
elif activation == "snake":
self.activation = Snake(channels)
elif activation == "half_snake":
self.activation = HalfSnake(channels)
else:
raise ValueError(f"Unknown activation {activation}")
def forward(self, x):
return self.activation(x)
class Conv1dNorm(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
padding: Optional[int] = None,
):
super().__init__()
if not padding:
padding = get_padding(kernel_size=kernel_size, dilation=dilation)
conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
padding_mode="reflect",
)
self.conv = nn.utils.weight_norm(conv)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)
def forward(self, inputs, input_len):
out = self.conv(inputs)
out = mask_sequence_tensor(out, input_len)
return out
class ConvTranspose1dNorm(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1
):
super().__init__()
padding, output_padding = get_up_sample_padding(kernel_size, stride)
conv = nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
padding_mode="zeros",
)
self.conv = nn.utils.weight_norm(conv)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)
def forward(self, inputs, input_len):
out = self.conv(inputs)
out = mask_sequence_tensor(out, input_len)
return out
class Conv2dNorm(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int],
stride: Tuple[int, int] = (1, 1),
dilation: Tuple[int, int] = (1, 1),
):
super().__init__()
assert len(kernel_size) == len(dilation)
padding = get_padding_2d(kernel_size, dilation)
conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="reflect",
)
self.conv = nn.utils.weight_norm(conv)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv)
def forward(self, inputs):
return self.conv(inputs)
class PeriodDiscriminator(nn.Module):
"""
Period discriminator introduced in HiFi-GAN https://arxiv.org/abs/2010.05646 which attempts to
discriminate phase information by looking at equally spaced audio samples.
Args:
period: Spacing between audio sample inputs.
lrelu_slope: Slope to use for activation. Leaky relu with slope of 0.1 or 0.2 is recommended for the
stability of the feature matching loss.
"""
def __init__(self, period, lrelu_slope=0.1):
super().__init__()
self.period = period
self.activation = nn.LeakyReLU(lrelu_slope)
self.conv_layers = nn.ModuleList(
[
Conv2dNorm(1, 32, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(32, 128, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(128, 512, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(512, 1024, kernel_size=(5, 1), stride=(3, 1)),
Conv2dNorm(1024, 1024, kernel_size=(5, 1), stride=(1, 1)),
]
)
self.conv_post = Conv2dNorm(1024, 1, kernel_size=(3, 1))
def forward(self, audio):
batch_size, time = audio.shape
out = rearrange(audio, "B T -> B 1 T")
# Pad audio so that it is divisible by the period
if time % self.period != 0:
n_pad = self.period - (time % self.period)
out = F.pad(out, (0, n_pad), "reflect")
time = time + n_pad
# [batch, 1, (time / period), period]
out = out.view(batch_size, 1, time // self.period, self.period)
fmap = []
for conv in self.conv_layers:
# [batch, filters, (time / period / stride), period]
out = conv(inputs=out)
out = self.activation(out)
fmap.append(out)
# [batch, 1, (time / period / strides), period]
score = self.conv_post(inputs=out)
fmap.append(score)
score = rearrange(score, "B 1 T C -> B C T")
return score, fmap
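# ---------------------------------------------------------------------------
# Shape sketch for PeriodDiscriminator (added for illustration).
# The waveform is folded into a [B, 1, T / period, period] grid so that each
# column holds samples spaced `period` apart; the (5, 1) kernels then convolve
# along time only, and the final score keeps one channel per period offset.
def _demo_period_discriminator():
    import torch

    disc = PeriodDiscriminator(period=3)
    audio = torch.randn(2, 6000)                        # 6000 is divisible by 3
    score, fmap = disc(audio=audio)
    assert score.shape[0] == 2 and score.shape[1] == 3  # [B, period, T']
    assert len(fmap) == 6                               # 5 conv layers + conv_post
    return score.shape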
class MultiPeriodDiscriminator(nn.Module):
"""
Wrapper class to aggregate results of multiple period discriminators.
The periods are expected to be increasing prime numbers in order to maximize coverage and minimize overlap.
"""
def __init__(self, periods: Iterable[int] = (2, 3, 5, 7, 11), lrelu_slope=0.1):
super().__init__()
self.discriminators = nn.ModuleList(
[
PeriodDiscriminator(period=period, lrelu_slope=lrelu_slope)
for period in periods
]
)
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
fmaps_real = []
fmaps_gen = []
for discriminator in self.discriminators:
score_real, fmap_real = discriminator(audio=audio_real)
score_gen, fmap_gen = discriminator(audio=audio_gen)
scores_real.append(score_real)
fmaps_real.append(fmap_real)
scores_gen.append(score_gen)
fmaps_gen.append(fmap_gen)
return scores_real, scores_gen, fmaps_real, fmaps_gen
class DiscriminatorSTFT(nn.Module):
"""
Discriminator network from EnCodec for Complex STFT input, but without dilations.
Args:
filters: number of filters to use in Conv2d layers
lrelu_slope: Slope to use for activations. Leaky relu with slope of 0.1 or 0.2 is recommended for the
stability of the feature matching loss
"""
def __init__(self, filters: int = 32, lrelu_slope: float = 0.1):
super().__init__()
self.activation = nn.LeakyReLU(lrelu_slope)
self.conv_layers = nn.ModuleList(
[
Conv2dNorm(2, filters, kernel_size=(3, 9)),
Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)),
Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)),
Conv2dNorm(filters, filters, kernel_size=(3, 9), stride=(1, 2)),
Conv2dNorm(filters, filters, kernel_size=(3, 3)),
]
)
self.conv_post = Conv2dNorm(filters, 1, kernel_size=(3, 3))
def forward(self, spec):
fmap = []
# [batch, 2, T_spec, fft]
out = spec
for conv in self.conv_layers:
# [batch, filters, T_spec, fft // strides]
out = conv(inputs=out)
out = self.activation(out)
fmap.append(out)
# [batch, 1, T_spec, fft // 8]
scores = self.conv_post(inputs=out)
fmap.append(scores)
scores = rearrange(scores, "B 1 T C -> B C T")
return scores, fmap
class MultiBandDiscriminatorSTFT(nn.Module):
"""
Multi-band STFT discriminator proposed in DAC (https://arxiv.org/abs/2306.06546).
Computes the complex STFT for a given resolution and splits it into sub-bands,
which are given to separate discriminator networks.
Args:
resolution: STFT resolution, provided as a tuple of 3 integers ordered (num_fft, hop_length, window_length)
stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end).
The floats are in the range [0, 1] representing the fraction of all stft bands.
For example for n_fft=1024, the stft output has 513 dimensions.
For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512].
"""
def __init__(
self, resolution: Tuple[int, int, int], stft_bands: Iterable[Tuple[float, float]]
):
super().__init__()
self.n_fft, self.hop_length, self.win_length = resolution
self.register_buffer(
"window", torch.hann_window(self.win_length, periodic=False)
)
self.discriminators = nn.ModuleList([DiscriminatorSTFT() for _ in stft_bands])
n_stft = self.n_fft // 2 + 1
self.stft_bands = [
(int(band[0] * n_stft), int(band[1] * n_stft)) for band in stft_bands
]
def compute_stft(self, audio):
# [B, fft, T_spec]
fft = torch.stft(
audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
normalized=True,
center=True,
return_complex=True,
)
fft = rearrange(fft, "B fft T -> B T fft")
# [batch, 2, T_spec, fft]
out = torch.stack([fft.real, fft.imag], dim=1)
return out
def forward(self, audio):
scores_list = []
fmap_list = []
spec = self.compute_stft(audio)
for band, disc in zip(self.stft_bands, self.discriminators):
spec_band = spec[:, :, :, band[0] : band[1]]
score, fmap = disc(spec=spec_band)
scores_list.append(score)
fmap_list.append(fmap)
return scores_list, fmap_list
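# ---------------------------------------------------------------------------
# Band-splitting sketch for MultiBandDiscriminatorSTFT (added for illustration).
# With n_fft=1024 the complex STFT has 1024 // 2 + 1 = 513 frequency bins, so
# the fractional bands (0.0, 0.25) and (0.25, 1.0) map to bin ranges [0, 128)
# and [128, 513), each handled by its own DiscriminatorSTFT.
def _demo_stft_bands():
    import torch

    disc = MultiBandDiscriminatorSTFT(
        resolution=(1024, 256, 1024), stft_bands=[(0.0, 0.25), (0.25, 1.0)]
    )
    assert disc.stft_bands == [(0, 128), (128, 513)]
    scores, fmaps = disc(torch.randn(2, 8192))
    assert len(scores) == 2 and len(fmaps) == 2
    return [s.shape for s in scores]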
class MultiResolutionDiscriminatorSTFT(nn.Module):
"""
Multi-resolution discriminator which creates a multi-band discriminator for each input resolution.
Args:
resolutions: List of STFT resolutions, each resolution provided as a tuple of 3 integers ordered
(num_fft, hop_length, window_length)
stft_bands: List of tuples, with each tuple having 2 float values (band_start, band_end).
The floats are in the range [0, 1] representing the fraction of all stft bands.
For example for n_fft=1024, the stft output has 513 dimensions.
For band input [(0, 0.25), (0.25, 1.0)] it would use stft dimensions [0 through 127] and [128 through 512].
"""
def __init__(
self, resolutions: Iterable[Tuple[int, int, int]], stft_bands: Iterable[Tuple[float, float]]
):
super().__init__()
self.discriminators = nn.ModuleList(
[
MultiBandDiscriminatorSTFT(resolution=resolution, stft_bands=stft_bands)
for resolution in resolutions
]
)
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
fmaps_real = []
fmaps_gen = []
for disc in self.discriminators:
score_real_i, fmap_real_i = disc(audio=audio_real)
scores_real = scores_real + score_real_i
fmaps_real = fmaps_real + fmap_real_i
score_gen_i, fmap_gen_i = disc(audio=audio_gen)
scores_gen = scores_gen + score_gen_i
fmaps_gen = fmaps_gen + fmap_gen_i
return scores_real, scores_gen, fmaps_real, fmaps_gen
class Discriminator(nn.Module):
"""
Wrapper class which takes a list of discriminators and aggregates the results across them.
"""
def __init__(self, discriminators: Iterable[nn.Module]):
super().__init__()
self.discriminators = nn.ModuleList(discriminators)
def forward(self, audio_real, audio_gen):
scores_real = []
scores_gen = []
fmaps_real = []
fmaps_gen = []
for discriminator in self.discriminators:
score_real, score_gen, fmap_real, fmap_gen = discriminator(
audio_real=audio_real, audio_gen=audio_gen
)
scores_real += score_real
fmaps_real += fmap_real
scores_gen += score_gen
fmaps_gen += fmap_gen
return scores_real, scores_gen, fmaps_real, fmaps_gen
class VectorQuantizerBase(nn.Module, ABC):
@abstractmethod
def forward(self, inputs: Tensor, input_len: Tensor) -> Tuple[Tensor, Tensor]:
pass
@abstractmethod
def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor:
pass
@abstractmethod
def decode(self, indices: Tensor, input_len: Tensor) -> Tensor:
pass
class FiniteScalarQuantizer(VectorQuantizerBase):
"""This quantizer is based on the Finite Scalar Quantization (FSQ) method.
It quantizes each element of the input vector independently into a number of levels.
Args:
num_levels: number of levels for each dimension/element of the input vector
eps: small regularization constant for scaling
References:
Mentzer et al., Finite Scalar Quantization: VQ-VAE Made Simple (https://arxiv.org/abs/2309.15505v1)
"""
def __init__(self, num_levels: List[int], eps: float = 1e-3):
super().__init__()
# index base per dimension of the input vector
# this is used to convert between per-dimension indices and a codebook token index
dim_base_index = torch.cumprod(
torch.tensor([1] + num_levels[:-1]), dim=0, dtype=torch.int32
)
dim_base_index = rearrange(dim_base_index, "D -> 1 D 1")
self.register_buffer("dim_base_index", dim_base_index)
# Register the number of levels for each dimension
num_levels = torch.tensor(num_levels, dtype=torch.int32)
num_levels = rearrange(num_levels, "D -> 1 D 1")
self.register_buffer("num_levels", num_levels)
# Regularization
self.eps = eps
logging.debug("Initializing %s with", self.__class__.__name__)
logging.debug("\tdim: %s", self.dim)
logging.debug("\tnum_levels: %s", self.num_levels)
logging.debug("\tcodebook_size: %s", self.codebook_size)
logging.debug("\teps: %s", self.eps)
@property
def codebook_size(self):
"""Returns the size of the corresponding codebook."""
return self.num_levels.prod().item()
@property
def dim(self):
"""Returns the dimension of the input vector."""
return self.num_levels.numel()
@property
def codebook_dim(self):
"""Returns the dimension of the input vector.
Kept for compatibility with the original RVQ implementation.
"""
return self.dim
@property
def codes(self):
"""Returns the codebooks entries.
Note that the codebook entries are implicitly defined by the number of levels.
"""
indices = torch.arange(self.codebook_size)
# [D, B, T]
indices = rearrange(indices, "B -> 1 B 1")
# [B, D, T]
codes = self.decode(indices=indices, input_len=None)
# Remove the time dimension
codes = codes.squeeze(-1)
return codes
@property
def codebook(self):
"""Returns the codebooks entries.
See self.codes for more details.
"""
return self.codes
@staticmethod
def round(inputs: Tensor, input_len: Tensor) -> Tensor:
"""Round the input tensor to nearest integer
and use a straight-through estimator for the gradient.
"""
inputs_rounded = torch.round(inputs)
return inputs + (inputs_rounded - inputs).detach()
def compress(self, inputs: Tensor, input_len: Tensor) -> Tensor:
"""Apply compression to the input, to limit to values."""
output_scale = (self.num_levels - 1) / 2
# scale down a bit to avoid rounding issues
output_scale = output_scale * (1 - self.eps)
# offset for even number of levels
output_offset = torch.where(self.num_levels % 2 == 0, 0.5, 0)
# shift for even number of levels
input_shift = (output_offset / output_scale).tan()
# compressed output
output = output_scale * (inputs + input_shift).tanh() - output_offset
return output
def inputs_to_codes(self, inputs: Tensor, input_len: Tensor) -> Tensor:
# apply compression
compressed = self.compress(inputs=inputs, input_len=input_len)
# apply rounding to nearest integer
codes = self.round(inputs=compressed, input_len=input_len)
# normalize to [-1, 1]
scale = self.num_levels // 2
codes = codes / scale
return codes
def codes_to_nonnegative(self, codes: Tensor) -> Tensor:
"""Convert values centered arouund zero to nonnegative values."""
scale = offset = self.num_levels // 2
return scale * codes + offset
def nonnegative_to_codes(self, codes_nonnegative: Tensor) -> Tensor:
"""Convert nonnegative values to values centered arouund zero."""
scale = offset = self.num_levels // 2
return (codes_nonnegative - offset) / scale
def codes_to_indices(self, codes: Tensor) -> Tensor:
"""Converts a code vector to a single index."""
if codes.size(1) != self.dim:
raise RuntimeError(
f"Input code dimension {codes.size(1)} not matching the expected dimension {self.dim}, input codes shape {codes.shape}"
)
# convert code vectors to nonnegative values
indices = self.codes_to_nonnegative(codes)
# convert one nonnegative index per dimension to a single index per code vector
indices = torch.sum(indices * self.dim_base_index, dim=1)
return indices.to(torch.int32)
# Implementation of VectorQuantiserBase API
def forward(
self, inputs: Tensor, input_len: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
if inputs.size(1) != self.dim:
raise RuntimeError(
f"Input dimension {inputs.size(1)} not matching the expected dimension {self.dim}, inputs shape {inputs.shape}"
)
dequantized = self.inputs_to_codes(inputs=inputs, input_len=input_len)
indices = self.codes_to_indices(codes=dequantized)
if input_len is not None:
# apply masking
dequantized = mask_sequence_tensor(dequantized, input_len)
indices = mask_sequence_tensor(indices, input_len)
# only 1 codebook, but return in [D, B, T] format to match RVQ API
indices = indices.unsqueeze(0)
return dequantized, indices
def encode(self, inputs: Tensor, input_len: Optional[Tensor] = None) -> Tensor:
"""Convert a continuous code vector to a single index."""
_, indices = self(inputs=inputs, input_len=input_len)
return indices
def decode(self, indices: Tensor, input_len: Optional[Tensor] = None) -> Tensor:
"""Convert a single index to a continuous code vector."""
if indices.size(0) > 1:
# codebook dimension used for compatibility with RVQ
raise ValueError(
f"Expected a single codebook, got {indices.size(0)} codebooks for indices with shape {indices.shape}."
)
indices = rearrange(indices, "D B T -> B D T")
# convert a single index to nonnegative index per-dimension
codes_nonnegative = (indices // self.dim_base_index) % self.num_levels
# convert nonnegative codes to codes (centered around zero)
dequantized = self.nonnegative_to_codes(codes_nonnegative)
if input_len is not None:
# apply masking
dequantized = mask_sequence_tensor(dequantized, input_len)
return dequantized
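# ---------------------------------------------------------------------------
# Worked example for FiniteScalarQuantizer (added for illustration).
# With num_levels=[8, 7, 6, 6] the implicit codebook has 8 * 7 * 6 * 6 = 2016
# entries; each 4-dim code vector maps to one index via the mixed-radix bases
# [1, 8, 56, 336] stored in `dim_base_index`.
def _demo_fsq():
    import torch

    fsq = FiniteScalarQuantizer(num_levels=[8, 7, 6, 6])
    assert fsq.dim == 4 and fsq.codebook_size == 2016

    inputs = torch.randn(2, 4, 10)                      # [B, D, T]
    input_len = torch.tensor([10, 7])
    dequantized, indices = fsq(inputs=inputs, input_len=input_len)
    # dequantized: [2, 4, 10], values on the per-dimension level grid in [-1, 1]
    # indices: [1, 2, 10], a single codebook kept in RVQ-style [C, B, T] layout
    reconstructed = fsq.decode(indices=indices, input_len=input_len)
    assert torch.allclose(reconstructed, dequantized, atol=1e-6)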
class GroupFiniteScalarQuantizer(VectorQuantizerBase):
"""Split the input vector into groups and apply FSQ on each group separately.
This class is for convenience. Since FSQ is applied on each group separately,
groups can be defined arbitrarily by splitting the input vector. However, this
class makes it easy to construct several groups with the same quantization num_levels.
Args:
num_groups: number of groups to split the input into; each group is quantized separately with its own FSQ.
num_levels_per_group: number of quantization levels for each dimension within a group. The per-group input dimension is len(num_levels_per_group), so the total input dimension is num_groups * len(num_levels_per_group).
**kwargs: parameters of FiniteScalarQuantizer
References:
Yang et al, HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec, 2023 (http://arxiv.org/abs/2305.02765).
"""
def __init__(self, num_groups: int, num_levels_per_group: List[int], **kwargs):
super().__init__()
self.num_groups = num_groups
self.codebook_dim_per_group = len(num_levels_per_group)
# Initialize FSQ for each group
self.fsqs = torch.nn.ModuleList(
[
FiniteScalarQuantizer(num_levels=num_levels_per_group, **kwargs)
for _ in range(self.num_groups)
]
)
logging.debug("Initialized %s with", self.__class__.__name__)
logging.debug("\tnum_groups: %d", self.num_groups)
logging.debug("\tcodebook_dim: %d", self.codebook_dim)
logging.debug("\tnum_levels_per_group: %s", num_levels_per_group)
logging.debug("\tcodebook_dim_per_group: %d", self.codebook_dim_per_group)
@property
def codebook_dim(self):
"""Input vector dimension."""
return self.codebook_dim_per_group * self.num_groups
@property
def codebook_size_per_group(self):
"""Returns the size of the implicit codebook for each group."""
return self.fsqs[0].codebook_size
@property
def codebook_size(self):
"""Returns the size of the implicit codebook."""
return self.codebook_size_per_group**self.num_groups
def forward(self, inputs, input_len):
"""Quantize each group separately, then concatenate the results."""
inputs_grouped = inputs.chunk(self.num_groups, dim=1)
dequantized, indices = [], []
for in_group, fsq_group in zip(inputs_grouped, self.fsqs):
dequantized_group, indices_group = fsq_group(
inputs=in_group, input_len=input_len
)
dequantized.append(dequantized_group)
indices.append(indices_group)
# concatenate along the feature dimension
dequantized = torch.cat(dequantized, dim=1)
# concatenate along the codebook dimension
indices = torch.cat(indices, dim=0)
return dequantized, indices
def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor:
"""Input is split into groups, each group is encoded separately, then the results are concatenated."""
inputs_grouped = inputs.chunk(self.num_groups, dim=1)
indices = []
for in_group, fsq_group in zip(inputs_grouped, self.fsqs):
indices_group = fsq_group.encode(inputs=in_group, input_len=input_len)
indices.append(indices_group)
# concatenate along the codebook dimension
indices = torch.cat(indices, dim=0)
return indices
def decode(self, indices: Tensor, input_len: Tensor) -> Tensor:
"""Input indices are split into groups, each group is decoded separately, then the results are concatenated."""
indices_grouped = indices.chunk(self.num_groups, dim=0)
dequantized = []
for indices_group, fsq_group in zip(indices_grouped, self.fsqs):
dequantized_group = fsq_group.decode(
indices=indices_group, input_len=input_len
)
dequantized.append(dequantized_group)
# concatenate along the feature dimension
dequantized = torch.cat(dequantized, dim=1)
return dequantized
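# ---------------------------------------------------------------------------
# Worked example for GroupFiniteScalarQuantizer (added for illustration).
# AudioCodecModel uses 8 groups with per-group levels [8, 7, 6, 6]: the 32-dim
# encoder output is split into eight 4-dim groups, each with its own 2016-entry
# FSQ codebook, giving tokens in [C=8, B, T] layout.
def _demo_group_fsq():
    import torch

    gfsq = GroupFiniteScalarQuantizer(num_groups=8, num_levels_per_group=[8, 7, 6, 6])
    assert gfsq.codebook_dim == 32
    assert gfsq.codebook_size_per_group == 2016

    encoded = torch.randn(2, 32, 10)                    # [B, D, T_encoded]
    encoded_len = torch.tensor([10, 6])
    tokens = gfsq.encode(inputs=encoded, input_len=encoded_len)
    dequantized = gfsq.decode(indices=tokens, input_len=encoded_len)
    assert tokens.shape == (8, 2, 10) and dequantized.shape == (2, 32, 10)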
class ResidualBlock(nn.Module):
"""
The residual block structure defined by the HiFi-GAN V1 and V2 configurations.
Args:
channels: Input dimension.
filters: Number of channels in the residual convolutions.
kernel_size: Kernel size of the residual convolutions.
dilation: Dilation of the residual convolutions.
dropout_rate: Dropout to apply to residuals.
activation: Activation to apply in between residual convolutions.
"""
def __init__(
self,
channels: int,
filters: int,
kernel_size: int = 3,
dilation: int = 1,
dropout_rate: float = 0.0,
activation: str = "lrelu",
):
super(ResidualBlock, self).__init__()
self.input_activation = CodecActivation(
activation=activation, channels=channels
)
self.skip_activation = CodecActivation(activation=activation, channels=filters)
self.dropout = torch.nn.Dropout(dropout_rate)
self.input_conv = Conv1dNorm(
in_channels=channels,
out_channels=filters,
kernel_size=kernel_size,
dilation=dilation,
)
self.skip_conv = Conv1dNorm(
in_channels=filters, out_channels=channels, kernel_size=kernel_size
)
def remove_weight_norm(self):
self.input_conv.remove_weight_norm()
self.skip_conv.remove_weight_norm()
def forward(self, inputs, input_len):
conv_input = self.input_activation(inputs)
skip_input = self.input_conv(inputs=conv_input, input_len=input_len)
skip_input = self.skip_activation(skip_input)
res = self.skip_conv(inputs=skip_input, input_len=input_len)
res = self.dropout(res)
out = inputs + res
return out
class HiFiGANResBlock(nn.Module):
"""
Residual block wrapper for HiFi-GAN which creates a block for multiple dilations.
Args:
channels: Input dimension.
kernel_size: Kernel size of the residual blocks.
dilations: List of dilations. One residual block will be created for each dilation in the list.
activation: Activation for the residual blocks.
"""
def __init__(
self, channels: int, kernel_size: int, dilations: Iterable[int], activation: str
):
super().__init__()
self.res_blocks = nn.ModuleList(
[
ResidualBlock(
channels=channels,
filters=channels,
kernel_size=kernel_size,
dilation=dilation,
activation=activation,
)
for dilation in dilations
]
)
def remove_weight_norm(self):
for res_block in self.res_blocks:
res_block.remove_weight_norm()
def forward(self, inputs, input_len):
out = inputs
for res_block in self.res_blocks:
out = res_block(inputs=out, input_len=input_len)
return out
class HiFiGANResLayer(nn.Module):
"""
Residual block wrapper for HiFi-GAN which creates a block for multiple kernel sizes and dilations.
One residual block is created for each combination of kernel size and dilation.
Args:
channels: Input dimension.
kernel_sizes: List of kernel sizes.
dilations: List of dilations.
activation: Activation for the residual layers.
"""
def __init__(
self,
channels: int,
kernel_sizes: Iterable[int],
dilations: Iterable[int],
activation: str,
):
super().__init__()
self.res_blocks = nn.ModuleList(
[
HiFiGANResBlock(
channels=channels,
kernel_size=kernel_size,
dilations=dilations,
activation=activation,
)
for kernel_size in kernel_sizes
]
)
def remove_weight_norm(self):
for res_block in self.res_blocks:
res_block.remove_weight_norm()
def forward(self, inputs, input_len):
residuals = [
res_block(inputs=inputs, input_len=input_len)
for res_block in self.res_blocks
]
out = sum(residuals) / len(residuals)
return out
class HiFiGANEncoder(nn.Module):
"""
Audio encoder created by inverting the HiFi-GAN decoder.
Args:
encoded_dim: Dimension of encoder output.
down_sample_rates: Rate to downsample for each encoder block. The product of the downsample rates will
determine the output token rate. For example 2 * 2 * 8 * 8 = 256 samples per token.
base_channels: Number of filters in the first convolution. The number of channels will be doubled after each
downsample layer.
in_kernel_size: Kernel size of the input convolution.
out_kernel_size: Kernel size of the output convolution.
resblock_kernel_sizes: List of kernel sizes to use in each residual block.
resblock_dilation_sizes: List of dilations to use in each residual block.
activation: Activation to use in residual and downsample layers, defaults to leaky relu.
"""
def __init__(
self,
encoded_dim: int,
down_sample_rates: Iterable[int] = (2, 2, 8, 8),
base_channels: int = 32,
in_kernel_size: int = 7,
out_kernel_size: int = 7,
resblock_kernel_sizes: Iterable[int] = (3, 7, 11),
resblock_dilation_sizes: Iterable[int] = (1, 3, 5),
activation: str = "lrelu",
):
assert in_kernel_size > 0
assert out_kernel_size > 0
super().__init__()
self.down_sample_rates = down_sample_rates
self.pre_conv = Conv1dNorm(
in_channels=1, out_channels=base_channels, kernel_size=in_kernel_size
)
in_channels = base_channels
self.activations = nn.ModuleList([])
self.down_sample_conv_layers = nn.ModuleList([])
self.res_layers = nn.ModuleList([])
for i, down_sample_rate in enumerate(self.down_sample_rates):
res_layer = HiFiGANResLayer(
channels=in_channels,
kernel_sizes=resblock_kernel_sizes,
dilations=resblock_dilation_sizes,
activation=activation,
)
self.res_layers.append(res_layer)
act = CodecActivation(activation, channels=in_channels)
self.activations.append(act)
out_channels = 2 * in_channels
kernel_size = 2 * down_sample_rate
padding = get_down_sample_padding(
kernel_size=kernel_size, stride=down_sample_rate
)
down_sample_conv = Conv1dNorm(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=down_sample_rate,
padding=padding,
)
in_channels = out_channels
self.down_sample_conv_layers.append(down_sample_conv)
self.post_activation = CodecActivation(activation, channels=in_channels)
self.post_conv = Conv1dNorm(
in_channels=in_channels,
out_channels=encoded_dim,
kernel_size=out_kernel_size,
)
def remove_weight_norm(self):
self.pre_conv.remove_weight_norm()
self.post_conv.remove_weight_norm()
for res_layer in self.res_layers:
res_layer.remove_weight_norm()
for down_sample_conv in self.down_sample_conv_layers:
down_sample_conv.remove_weight_norm()
def forward(self, audio, audio_len):
encoded_len = audio_len
audio = rearrange(audio, "B T -> B 1 T")
# [B, C, T_audio]
out = self.pre_conv(inputs=audio, input_len=encoded_len)
for act, res_layer, down_sample_conv, down_sample_rate in zip(
self.activations,
self.res_layers,
self.down_sample_conv_layers,
self.down_sample_rates,
):
# [B, C, T]
out = res_layer(inputs=out, input_len=encoded_len)
out = act(out)
encoded_len = encoded_len // down_sample_rate
# [B, 2 * C, T / down_sample_rate]
out = down_sample_conv(inputs=out, input_len=encoded_len)
out = self.post_activation(out)
# [B, encoded_dim, T_encoded]
encoded = self.post_conv(inputs=out, input_len=encoded_len)
return encoded, encoded_len
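# ---------------------------------------------------------------------------
# Shape sketch for HiFiGANEncoder (added for illustration).
# With the down-sample rates used by AudioCodecModel, (2, 2, 4, 8, 8), one
# encoded frame covers 2 * 2 * 4 * 8 * 8 = 1024 samples, and the channel count
# doubles at every stage (48 -> 96 -> 192 -> 384 -> 768 -> 1536) before the
# final projection to `encoded_dim`.
def _demo_hifigan_encoder():
    import torch

    encoder = HiFiGANEncoder(
        encoded_dim=32,
        down_sample_rates=(2, 2, 4, 8, 8),
        base_channels=48,
        resblock_dilation_sizes=(1,),
    )
    audio = torch.randn(2, 8192)                        # eight 1024-sample frames
    audio_len = torch.tensor([8192, 4096])
    encoded, encoded_len = encoder(audio=audio, audio_len=audio_len)
    assert encoded.shape == (2, 32, 8)
    assert encoded_len.tolist() == [8, 4]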
class HiFiGANDecoder(nn.Module):
"""
Codec decoder using the HiFi-GAN generator architecture.
Default parameters match the HiFi-GAN V1 configuration for 22.05 kHz.
Args:
input_dim: Input dimension.
up_sample_rates: Rate to upsample for each decoder block. The product of the upsample rates should be the same
as the overall downsample rate for your encoder. For example, a symmetric encoder/decoder can be created
with encoder downsample rates [2, 2, 8, 8] and decoder upsample rates [8, 8, 2, 2].
base_channels: Number of filters in the first convolution. The number of channels will be cut in
half after each upsample layer.
in_kernel_size: Kernel size of the input convolution.
out_kernel_size: Kernel size of the output convolution.
resblock_kernel_sizes: List of kernel sizes to use in each residual block.
resblock_dilation_sizes: List of dilations to use in each residual block.
activation: Activation to use in residual and upsample layers, defaults to leaky relu.
output_activation: Activation to apply to output. To produce a valid audio signal, it should output values in
the range [-1.0, 1.0]. Supports "tanh" and "clamp".
"""
def __init__(
self,
input_dim: int,
up_sample_rates: Iterable[int] = (8, 8, 2, 2),
base_channels: int = 512,
in_kernel_size: int = 7,
out_kernel_size: int = 3,
resblock_kernel_sizes: Iterable[int] = (3, 7, 11),
resblock_dilation_sizes: Iterable[int] = (1, 3, 5),
activation: str = "lrelu",
output_activation: str = "tanh",
):
assert in_kernel_size > 0
assert out_kernel_size > 0
super().__init__()
self.up_sample_rates = up_sample_rates
self.pre_conv = Conv1dNorm(
in_channels=input_dim,
out_channels=base_channels,
kernel_size=in_kernel_size,
)
in_channels = base_channels
self.activations = nn.ModuleList([])
self.up_sample_conv_layers = nn.ModuleList([])
self.res_layers = nn.ModuleList([])
for i, up_sample_rate in enumerate(self.up_sample_rates):
out_channels = in_channels // 2
kernel_size = 2 * up_sample_rate
act = CodecActivation(activation, channels=in_channels)
self.activations.append(act)
up_sample_conv = ConvTranspose1dNorm(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=up_sample_rate,
)
in_channels = out_channels
self.up_sample_conv_layers.append(up_sample_conv)
res_layer = HiFiGANResLayer(
channels=in_channels,
kernel_sizes=resblock_kernel_sizes,
dilations=resblock_dilation_sizes,
activation=activation,
)
self.res_layers.append(res_layer)
self.post_activation = CodecActivation(activation, channels=in_channels)
self.post_conv = Conv1dNorm(
in_channels=in_channels, out_channels=1, kernel_size=out_kernel_size
)
if output_activation == "tanh":
self.out_activation = nn.Tanh()
elif output_activation == "clamp":
self.out_activation = ClampActivation()
else:
raise ValueError(f"Invalid audio output activation {output_activation}")
def remove_weight_norm(self):
self.pre_conv.remove_weight_norm()
for up_sample_conv in self.up_sample_conv_layers:
up_sample_conv.remove_weight_norm()
for res_layer in self.res_layers:
res_layer.remove_weight_norm()
def forward(self, inputs, input_len):
audio_len = input_len
# [B, C, T_encoded]
out = self.pre_conv(inputs=inputs, input_len=audio_len)
for act, res_layer, up_sample_conv, up_sample_rate in zip(
self.activations,
self.res_layers,
self.up_sample_conv_layers,
self.up_sample_rates,
):
audio_len = audio_len * up_sample_rate
out = act(out)
# [B, C / 2, T * up_sample_rate]
out = up_sample_conv(inputs=out, input_len=audio_len)
out = res_layer(inputs=out, input_len=audio_len)
out = self.post_activation(out)
# [B, 1, T_audio]
out = self.post_conv(inputs=out, input_len=audio_len)
audio = self.out_activation(out)
audio = rearrange(audio, "B 1 T -> B T")
return audio, audio_len
@torch.jit.script_if_tracing
def make_seq_mask_like(
lengths: Tensor,
like: Tensor,
time_dim: int = -1,
valid_ones: bool = True,
) -> Tensor:
"""
Args:
lengths: Tensor with shape [B] containing the sequence length of each batch element
like: The mask will contain the same number of dimensions as this Tensor, and will have the same max
length in the time dimension of this Tensor.
time_dim: Time dimension of the `like` tensor and the resulting mask. Zero-based.
valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert.
Returns:
A :class:`Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else
vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match
the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and
`time_dim == -1`, mask will have shape `[3, 1, 5]`.
"""
# Mask with shape [B, T]
mask = (
torch.arange(like.shape[time_dim], device=like.device)
.repeat(lengths.shape[0], 1)
.lt(lengths.view(-1, 1))
)
# [B, T] -> [B, *, T] where * is any number of singleton dimensions to expand to like tensor
for _ in range(like.dim() - mask.dim()):
mask = mask.unsqueeze(1)
# If needed, transpose time dim
if time_dim != -1 and time_dim != mask.dim() - 1:
mask = mask.transpose(-1, time_dim)
# Maybe invert the padded vs. valid token values
if not valid_ones:
mask = ~mask
return mask
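# ---------------------------------------------------------------------------
# Quick check for make_seq_mask_like (added for illustration).
# For a [B, D, T] tensor the mask keeps the batch and time dimensions and uses
# a singleton feature dimension so that it broadcasts against the input.
def _demo_make_seq_mask_like():
    import torch

    like = torch.zeros(3, 4, 5)
    lengths = torch.tensor([5, 3, 0])
    mask = make_seq_mask_like(lengths=lengths, like=like)
    assert mask.shape == (3, 1, 5)
    assert mask[1, 0].tolist() == [True, True, True, False, False]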
def normalize_batch(x, seq_len, normalize_type):
x_mean = None
x_std = None
if normalize_type == "per_feature":
batch_size = x.shape[0]
max_time = x.shape[2]
# When doing stream capture to a graph, item() is not allowed
because it calls cudaStreamSynchronize(). Therefore, we are
# sacrificing some error checking when running with cuda graphs.
if (
torch.cuda.is_available()
and not torch.cuda.is_current_stream_capturing()
and torch.any(seq_len == 1).item()
):
raise ValueError(
"normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
"in torch.std() returning nan. Make sure your audio length has enough samples for a single "
"feature (ex. at least `hop_length` for Mel Spectrograms)."
)
time_steps = (
torch.arange(max_time, device=x.device)
.unsqueeze(0)
.expand(batch_size, max_time)
)
valid_mask = time_steps < seq_len.unsqueeze(1)
x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2)
x_mean_denominator = valid_mask.sum(axis=1)
x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1)
# Subtract 1 in the denominator to correct for the bias.
x_std = torch.sqrt(
torch.sum(
torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2,
axis=2,
)
/ (x_mean_denominator.unsqueeze(1) - 1.0)
)
# make sure x_std is not zero
x_std += CONSTANT
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
elif normalize_type == "all_features":
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
for i in range(x.shape[0]):
x_mean[i] = x[i, :, : seq_len[i].item()].mean()
x_std[i] = x[i, :, : seq_len[i].item()].std()
# make sure x_std is not zero
x_std += CONSTANT
return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
return (
(x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2))
/ x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
x_mean,
x_std,
)
else:
return x, x_mean, x_std
def splice_frames(x, frame_splicing):
"""Stacks frames together across feature dim
input is batch_size, feature_dim, num_frames
output is batch_size, feature_dim*frame_splicing, num_frames
"""
seq = [x]
for n in range(1, frame_splicing):
seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
return torch.cat(seq, dim=1)
class FilterbankFeatures(nn.Module):
"""Featurizer that converts wavs to Mel Spectrograms.
See AudioToMelSpectrogramPreprocessor for args.
"""
def __init__(
self,
sample_rate=16000,
n_window_size=320,
n_window_stride=160,
window="hann",
normalize="per_feature",
n_fft=None,
preemph=0.97,
nfilt=64,
lowfreq=0,
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2**-24,
dither=CONSTANT,
pad_to=16,
max_duration=16.7,
frame_splicing=1,
exact_pad=False,
pad_value=0,
mag_power=2.0,
use_grads=False,
rng=None,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
mel_norm="slaney",
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
stft_conv=False, # Deprecated arguments; kept for config compatibility
):
super().__init__()
if stft_conv or stft_exact_pad:
logging.warning(
"Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
"for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
"as needed."
)
if exact_pad and n_window_stride % 2 == 1:
raise NotImplementedError(
f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
"returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
)
self.log_zero_guard_value = log_zero_guard_value
if (
n_window_size is None
or n_window_stride is None
or not isinstance(n_window_size, int)
or not isinstance(n_window_stride, int)
or n_window_size <= 0
or n_window_stride <= 0
):
raise ValueError(
f"{self} got an invalid value for either n_window_size or "
f"n_window_stride. Both must be positive ints."
)
logging.info(f"PADDING: {pad_to}")
self.win_length = n_window_size
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
self.stft_pad_amount = (
(self.n_fft - self.hop_length) // 2 if exact_pad else None
)
self.exact_pad = exact_pad
if exact_pad:
logging.info("STFT using exact pad")
torch_windows = {
"hann": torch.hann_window,
"hamming": torch.hamming_window,
"blackman": torch.blackman_window,
"bartlett": torch.bartlett_window,
"none": None,
}
window_fn = torch_windows.get(window, None)
window_tensor = (
window_fn(self.win_length, periodic=False) if window_fn else None
)
self.register_buffer("window", window_tensor)
self.normalize = normalize
self.log = log
self.dither = dither
self.frame_splicing = frame_splicing
self.nfilt = nfilt
self.preemph = preemph
self.pad_to = pad_to
highfreq = highfreq or sample_rate / 2
import librosa
filterbanks = torch.tensor(
librosa.filters.mel(
sr=sample_rate,
n_fft=self.n_fft,
n_mels=nfilt,
fmin=lowfreq,
fmax=highfreq,
norm=mel_norm,
),
dtype=torch.float,
).unsqueeze(0)
self.register_buffer("fb", filterbanks)
# Calculate maximum sequence length
max_length = self.get_seq_len(
torch.tensor(max_duration * sample_rate, dtype=torch.float)
)
max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
self.max_length = max_length + max_pad
self.pad_value = pad_value
self.mag_power = mag_power
# We want to avoid taking the log of zero
# There are two options: either adding or clamping to a small value
if log_zero_guard_type not in ["add", "clamp"]:
raise ValueError(
f"{self} received {log_zero_guard_type} for the "
f"log_zero_guard_type parameter. It must be either 'add' or "
f"'clamp'."
)
self.use_grads = use_grads
if not use_grads:
self.forward = torch.no_grad()(self.forward)
self._rng = random.Random() if rng is None else rng
self.nb_augmentation_prob = nb_augmentation_prob
if self.nb_augmentation_prob > 0.0:
if nb_max_freq >= sample_rate / 2:
self.nb_augmentation_prob = 0.0
else:
self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
# log_zero_guard_value is the small value we want to use; we support
# an actual number, or "tiny", or "eps"
self.log_zero_guard_type = log_zero_guard_type
logging.debug(f"sr: {sample_rate}")
logging.debug(f"n_fft: {self.n_fft}")
logging.debug(f"win_length: {self.win_length}")
logging.debug(f"hop_length: {self.hop_length}")
logging.debug(f"n_mels: {nfilt}")
logging.debug(f"fmin: {lowfreq}")
logging.debug(f"fmax: {highfreq}")
logging.debug(f"using grads: {use_grads}")
logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")
def stft(self, x):
return torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=False if self.exact_pad else True,
window=self.window.to(dtype=torch.float),
return_complex=True,
)
def log_zero_guard_value_fn(self, x):
if isinstance(self.log_zero_guard_value, str):
if self.log_zero_guard_value == "tiny":
return torch.finfo(x.dtype).tiny
elif self.log_zero_guard_value == "eps":
return torch.finfo(x.dtype).eps
else:
raise ValueError(
f"{self} received {self.log_zero_guard_value} for the "
f"log_zero_guard_type parameter. It must be either a "
f"number, 'tiny', or 'eps'"
)
else:
return self.log_zero_guard_value
def get_seq_len(self, seq_len):
# When stft_pad_amount is None, assume center=True, i.e. the signal is padded by n_fft // 2 on each side
pad_amount = (
self.stft_pad_amount * 2
if self.stft_pad_amount is not None
else self.n_fft // 2 * 2
)
seq_len = (
torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
)
return seq_len.to(dtype=torch.long)
@property
def filter_banks(self):
return self.fb
def forward(self, x, seq_len, linear_spec=False):
seq_len = self.get_seq_len(seq_len)
if self.stft_pad_amount is not None:
x = torch.nn.functional.pad(
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
).squeeze(1)
# dither (only in training mode for eval determinism)
if self.training and self.dither > 0:
x += self.dither * torch.randn_like(x)
# do preemphasis
if self.preemph is not None:
x = torch.cat(
(x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1
)
# disable autocast to get full range of stft values
with torch.amp.autocast(x.device.type, enabled=False):
x = self.stft(x)
# torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
# guard is needed for sqrt if grads are passed through
guard = 0 if not self.use_grads else CONSTANT
x = torch.view_as_real(x)
x = torch.sqrt(x.pow(2).sum(-1) + guard)
if self.training and self.nb_augmentation_prob > 0.0:
for idx in range(x.shape[0]):
if self._rng.random() < self.nb_augmentation_prob:
x[idx, self._nb_max_fft_bin :, :] = 0.0
# get power spectrum
if self.mag_power != 1.0:
x = x.pow(self.mag_power)
# return plain spectrogram if required
if linear_spec:
return x, seq_len
# dot with filterbank energies
x = torch.matmul(self.fb.to(x.dtype), x)
# log features if required
if self.log:
if self.log_zero_guard_type == "add":
x = torch.log(x + self.log_zero_guard_value_fn(x))
elif self.log_zero_guard_type == "clamp":
x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
else:
raise ValueError("log_zero_guard_type was not understood")
# frame splicing if required
if self.frame_splicing > 1:
x = splice_frames(x, self.frame_splicing)
# normalize if required
if self.normalize:
x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)
# mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
max_len = x.size(-1)
mask = torch.arange(max_len, device=x.device)
mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
x = x.masked_fill(
mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value
)
del mask
pad_to = self.pad_to
if pad_to == "max":
x = nn.functional.pad(
x, (0, self.max_length - x.size(-1)), value=self.pad_value
)
elif pad_to > 0:
pad_amt = x.size(-1) % pad_to
if pad_amt != 0:
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
return x, seq_len
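# ---------------------------------------------------------------------------
# Frame-count sketch for FilterbankFeatures (added for illustration; building
# the featurizer requires librosa for the mel filterbank).
# With the default center-padded STFT (exact_pad=False) the number of frames is
# floor((num_samples + 2 * (n_fft // 2) - n_fft) / hop_length) + 1,
# i.e. num_samples // hop_length + 1 for an even n_fft.
def _demo_filterbank_frames():
    import torch

    featurizer = FilterbankFeatures(
        sample_rate=16000, n_window_size=320, n_window_stride=160
    )
    # n_fft defaults to the next power of two above the window size: 512.
    assert featurizer.n_fft == 512
    lengths = featurizer.get_seq_len(torch.tensor([16000, 8000]))
    assert lengths.tolist() == [101, 51]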
class AudioToMelSpectrogramPreprocessor(nn.Module):
"""Featurizer module that converts wavs to mel spectrograms.
Args:
sample_rate (int): Sample rate of the input audio data.
Defaults to 16000
window_size (float): Size of window for fft in seconds
Defaults to 0.02
window_stride (float): Stride of window for fft in seconds
Defaults to 0.01
n_window_size (int): Size of window for fft in samples
Defaults to None. Use one of window_size or n_window_size.
n_window_stride (int): Stride of window for fft in samples
Defaults to None. Use one of window_stride or n_window_stride.
window (str): Windowing function for fft. can be one of ['hann',
'hamming', 'blackman', 'bartlett']
Defaults to "hann"
normalize (str): Can be one of ['per_feature', 'all_features']; all
other options disable feature normalization. 'all_features'
normalizes the entire spectrogram to be mean 0 with std 1.
'per_feature' normalizes per channel / freq instead.
Defaults to "per_feature"
n_fft (int): Length of FT window. If None, it uses the smallest power
of 2 that is larger than n_window_size.
Defaults to None
preemph (float): Amount of pre emphasis to add to audio. Can be
disabled by passing None.
Defaults to 0.97
features (int): Number of mel spectrogram freq bins to output.
Defaults to 64
lowfreq (int): Lower bound on mel basis in Hz.
Defaults to 0
highfreq (int): Upper bound on mel basis in Hz.
Defaults to None
log (bool): Log features.
Defaults to True
log_zero_guard_type(str): Need to avoid taking the log of zero. There
are two options: "add" or "clamp".
Defaults to "add".
log_zero_guard_value(float, or str): Add or clamp requires the number
to add with or clamp to. log_zero_guard_value can either be a float
or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
passed.
Defaults to 2**-24.
dither (float): Amount of white-noise dithering.
Defaults to 1e-5
pad_to (int): Ensures that the output size of the time dimension is
a multiple of pad_to.
Defaults to 16
frame_splicing (int): Defaults to 1
exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
// hop_length. Defaults to False.
pad_value (float): The value that shorter mels are padded with.
Defaults to 0
mag_power (float): The power that the linear spectrogram is raised to
prior to multiplication with mel basis.
Defaults to 2 for a power spec
rng : Random number generator
nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
samples in the batch.
Defaults to 0.0
nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
Defaults to 4000
use_torchaudio: Whether to use the `torchaudio` implementation (unused in this stripped version; FilterbankFeatures is always used).
mel_norm: Normalization used for mel filterbank weights.
Defaults to 'slaney' (area normalization)
stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
"""
def __init__(
self,
sample_rate=16000,
window_size=0.02,
window_stride=0.01,
n_window_size=None,
n_window_stride=None,
window="hann",
normalize="per_feature",
n_fft=None,
preemph=0.97,
features=64,
lowfreq=0,
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2**-24,
dither=1e-5,
pad_to=16,
frame_splicing=1,
exact_pad=False,
pad_value=0,
mag_power=2.0,
rng=None,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
use_torchaudio: bool = False,
mel_norm="slaney",
stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
stft_conv=False, # Deprecated arguments; kept for config compatibility
):
super().__init__()  # this stripped class derives from nn.Module, which takes no constructor arguments
self._sample_rate = sample_rate
if window_size and n_window_size:
raise ValueError(
f"{self} received both window_size and "
f"n_window_size. Only one should be specified."
)
if window_stride and n_window_stride:
raise ValueError(
f"{self} received both window_stride and "
f"n_window_stride. Only one should be specified."
)
if window_size:
n_window_size = int(window_size * self._sample_rate)
if window_stride:
n_window_stride = int(window_stride * self._sample_rate)
# Given the long and similar argument list, point to the class and instantiate it by reference
featurizer_class = FilterbankFeatures
self.featurizer = featurizer_class(
sample_rate=self._sample_rate,
n_window_size=n_window_size,
n_window_stride=n_window_stride,
window=window,
normalize=normalize,
n_fft=n_fft,
preemph=preemph,
nfilt=features,
lowfreq=lowfreq,
highfreq=highfreq,
log=log,
log_zero_guard_type=log_zero_guard_type,
log_zero_guard_value=log_zero_guard_value,
dither=dither,
pad_to=pad_to,
frame_splicing=frame_splicing,
exact_pad=exact_pad,
pad_value=pad_value,
mag_power=mag_power,
rng=rng,
nb_augmentation_prob=nb_augmentation_prob,
nb_max_freq=nb_max_freq,
mel_norm=mel_norm,
stft_exact_pad=stft_exact_pad, # Deprecated arguments; kept for config compatibility
stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility
)
def input_example(
self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200
):
batch_size = torch.randint(low=1, high=max_batch, size=[1]).item()
max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item()
signals = torch.rand(size=[batch_size, max_length]) * 2 - 1
lengths = torch.randint(low=min_length, high=max_dim, size=[batch_size])
lengths[0] = max_length
return signals, lengths
def get_features(self, input_signal, length):
    return self.featurizer(input_signal, length)
def forward(self, input_signal, length):
    # In NeMo this forward comes from the AudioPreprocessor base class; it is
    # added here so the stripped module can be called directly (as done by
    # MelSpectrogramProcessor below).
    processed_signal, processed_length = self.get_features(input_signal, length)
    return processed_signal, processed_length
@property
def filter_banks(self):
return self.featurizer.filter_banks
class MelSpectrogramProcessor(nn.Module):
"""
Wrapper interface for computing mel spectrogram for codec training.
"""
def __init__(
self,
sample_rate: int,
win_length: int,
hop_length: int,
mel_dim: int = 80,
log_guard: float = 1.0,
):
super(MelSpectrogramProcessor, self).__init__()
self.mel_dim = mel_dim
self.hop_length = hop_length
self.preprocessor = AudioToMelSpectrogramPreprocessor(
sample_rate=sample_rate,
highfreq=None,
features=mel_dim,
pad_to=1,
exact_pad=True,
n_window_size=win_length,
n_window_stride=hop_length,
window_size=False,
window_stride=False,
n_fft=win_length,
mag_power=1.0,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=log_guard,
mel_norm=None,
normalize=None,
preemph=None,
dither=0.0,
)
def forward(self, audio, audio_len):
spec, spec_len = self.preprocessor(input_signal=audio, length=audio_len)
return spec, spec_len
class ResNetEncoder(nn.Module):
"""
Residual network which uses HiFi-GAN residual blocks to encode spectrogram features without changing
the time dimension.
Args:
in_channels: input dimension
out_channels: output dimension
num_layers: number of residual blocks to use
hidden_channels: encoder hidden dimension
filters: number of filters in residual block layers
kernel_size: kernel size in residual block convolutions
dropout_rate: Optional dropout rate to apply to residuals.
activation: Activation to use, defaults to leaky relu.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 6,
hidden_channels: int = 256,
filters: int = 768,
kernel_size: int = 3,
dropout_rate: float = 0.1,
activation: str = "lrelu",
):
super(ResNetEncoder, self).__init__()
self.pre_conv = Conv1dNorm(
in_channels=in_channels,
out_channels=hidden_channels,
kernel_size=kernel_size,
)
self.res_layers = nn.ModuleList(
[
ResidualBlock(
channels=hidden_channels,
filters=filters,
kernel_size=kernel_size,
dropout_rate=dropout_rate,
activation=activation,
)
for _ in range(num_layers)
]
)
self.post_activation = CodecActivation(activation, channels=hidden_channels)
self.post_conv = Conv1dNorm(
in_channels=hidden_channels,
out_channels=out_channels,
kernel_size=kernel_size,
)
def remove_weight_norm(self):
self.pre_conv.remove_weight_norm()
self.post_conv.remove_weight_norm()
for res_layer in self.res_layers:
res_layer.remove_weight_norm()
def forward(self, inputs, input_len):
encoded = self.pre_conv(inputs=inputs, input_len=input_len)
for res_layer in self.res_layers:
encoded = res_layer(inputs=encoded, input_len=input_len)
encoded = self.post_activation(encoded)
encoded = self.post_conv(inputs=encoded, input_len=input_len)
return encoded
class FullBandMelEncoder(nn.Module):
"""
Encoder which encodes the entire mel spectrogram with a single encoder network.
Args:
mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from
input audio.
encoder: ResNetEncoder or equivalent class for encoding the mel spectrogram.
"""
def __init__(self, mel_processor: nn.Module, encoder: nn.Module):
super(FullBandMelEncoder, self).__init__()
self.mel_processor = mel_processor
self.encoder = encoder
def remove_weight_norm(self):
self.encoder.remove_weight_norm()
def forward(self, audio, audio_len):
out, spec_len = self.mel_processor(audio=audio, audio_len=audio_len)
encoded = self.encoder(inputs=out, input_len=spec_len)
return encoded, spec_len
class MultiBandMelEncoder(nn.Module):
"""
Encoder which splits mel spectrogram into bands and encodes each using separate residual networks.
Args:
mel_bands: List of mel spectrogram bands to encode.
Each list element is tuple of 2 elements with the start and end index of the mel features to use.
mel_processor: MelSpectrogramProcessor or equivalent class instance for computing the mel spectrogram from
input audio.
encoder_kwargs: Arguments for constructing encoder for each mel band.
"""
def __init__(
self,
mel_bands: Iterable[Tuple[int, int]],
mel_processor: nn.Module,
**encoder_kwargs,
):
super(MultiBandMelEncoder, self).__init__()
self.validate_mel_bands(mel_dim=mel_processor.mel_dim, mel_bands=mel_bands)
self.mel_bands = mel_bands
self.mel_processor = mel_processor
band_dims = [band[1] - band[0] for band in self.mel_bands]
self.encoders = nn.ModuleList(
[
ResNetEncoder(in_channels=band_dim, **encoder_kwargs)
for band_dim in band_dims
]
)
@staticmethod
def validate_mel_bands(mel_dim: int, mel_bands: Iterable[Tuple[int, int]]):
mel_dims_used = np.zeros([mel_dim], dtype=bool)
for band in mel_bands:
mel_dims_used[band[0] : band[1]] = True
if not all(mel_dims_used):
missing_dims = np.where(~mel_dims_used)
raise ValueError(
f"Mel bands must cover all {mel_dim} dimensions. Missing {missing_dims}."
)
return
def remove_weight_norm(self):
for encoder in self.encoders:
encoder.remove_weight_norm()
def forward(self, audio, audio_len):
spec, spec_len = self.mel_processor(audio=audio, audio_len=audio_len)
outputs = []
for (band_start, band_end), encoder in zip(self.mel_bands, self.encoders):
# [B, D_band, T]
spec_band = spec[:, band_start:band_end, :]
band_out = encoder(inputs=spec_band, input_len=spec_len)
outputs.append(band_out)
# [B, C, T]
encoded = torch.cat(outputs, dim=1)
return encoded, spec_len
import einops
import torch
import torch.nn as nn
activation_registry = {
"identity": nn.Identity,
"hardtanh": nn.Hardtanh,
"relu": nn.ReLU,
"selu": nn.SELU,
"swish": nn.SiLU,
"silu": nn.SiLU,
"gelu": nn.GELU,
}
def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor):
"""
For tensors containing sequences, zero out out-of-bound elements given lengths of every element in the batch.
tensor: tensor of shape (B, L), (B, D, L) or (B, D1, D2, L),
lengths: LongTensor of shape (B,)
"""
batch_size, *_, max_lengths = tensor.shape
if len(tensor.shape) == 2:
mask = torch.ones(batch_size, max_lengths).cumsum(dim=-1).type_as(lengths)
mask = mask <= einops.rearrange(lengths, "B -> B 1")
elif len(tensor.shape) == 3:
mask = torch.ones(batch_size, 1, max_lengths).cumsum(dim=-1).type_as(lengths)
mask = mask <= einops.rearrange(lengths, "B -> B 1 1")
elif len(tensor.shape) == 4:
mask = torch.ones(batch_size, 1, 1, max_lengths).cumsum(dim=-1).type_as(lengths)
mask = mask <= einops.rearrange(lengths, "B -> B 1 1 1")
else:
raise ValueError(
"Can only mask tensors of shape B x L, B x D x L and B x D1 x D2 x L"
)
return tensor * mask
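# ---------------------------------------------------------------------------
# Quick check for mask_sequence_tensor (added for illustration).
# Positions past each example's valid length are zeroed; the mask broadcasts
# over any middle dimensions of the input.
def _demo_mask_sequence_tensor():
    import torch

    x = torch.ones(2, 3, 4)                             # [B, D, L]
    lengths = torch.tensor([4, 2])
    masked = mask_sequence_tensor(x, lengths)
    assert masked[1, 0].tolist() == [1.0, 1.0, 0.0, 0.0]
    assert masked[0].sum().item() == 12.0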
class ClampActivation(nn.Module):
def __init__(self, min_value: float = -1.0, max_value: float = 1.0):
super().__init__()
self.min_value = min_value
self.max_value = max_value
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.clamp(input, min=self.min_value, max=self.max_value)
@torch.jit.script
def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
"""
equation for snake activation function: x + (alpha + eps)^-1 * sin(alpha * x)^2
"""
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + eps).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake(nn.Module):
"""
Snake activation function introduced in 'https://arxiv.org/abs/2006.08195'
"""
def __init__(self, channels: int):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return snake(x, self.alpha)
class HalfSnake(nn.Module):
"""
Activation which applies snake to the first half of input elements and leaky relu to the second half.
"""
def __init__(self, channels: int):
super().__init__()
self.snake_channels = channels // 2
self.snake_act = Snake(self.snake_channels)
self.lrelu = torch.nn.LeakyReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
snake_out = self.snake_act(x[:, : self.snake_channels, :])
lrelu_out = self.lrelu(x[:, self.snake_channels :, :])
out = torch.cat([snake_out, lrelu_out], dim=1)
return out