@Ishotihadus
Last active September 19, 2025 05:30
A GPU-friendly wrapper for the USM audio encoder in Google Gemma 3n
# sudo apt install libsox-dev
# mkdir -p ~/.cache/huggingface && echo 'hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' > ~/.cache/huggingface/token
# uv run --with "transformers,torch,torchaudio<2.9" python
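# (Alternatively, authenticating via `huggingface-cli login` or the HF_TOKEN
#  environment variable should also work for downloading the Gemma 3n weights,
#  which may be gated behind a license acceptance.)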
import torchaudio
from google_usm import USMWrapper
usm = USMWrapper("google/gemma-3n-E4B")
wav, sr = torchaudio.load("VOICEACTRESS100_001.wav")
usm = usm.cuda()  # optional: the wrapper (mel frontend + encoder) can live on the GPU
wav = wav.cuda()  # GPU input tensors are supported when the wrapper is on the same device
features = usm(wav, sr)
# Input must be one of:
# - torch.FloatTensor[n_samples],
# - torch.FloatTensor[n_batches, n_samples],
# - list[torch.FloatTensor[n_samples]]
# features = usm([wav.squeeze(0)], sr)
# features = usm(wav.squeeze(0), sr)
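# A list of different-length mono clips is also accepted; shorter clips are
# zero-padded internally and marked in the output mask (sketch; `wav2` is a
# hypothetical second clip loaded the same way as above):
# features = usm([wav.squeeze(0), wav2], sr)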
features.last_hidden_state
# => tensor([[[ 0.0038, -0.0037, -0.0049, ..., -0.0050, -0.0004, -0.0032],
# [-0.0076, 0.0020, 0.0171, ..., -0.0107, -0.0007, 0.0070],
# [-0.0218, -0.0018, 0.0051, ..., -0.0096, -0.0008, 0.0052],
# ...,
# [ 0.0037, 0.0019, 0.0043, ..., 0.0164, 0.0194, -0.0064],
# [-0.0139, 0.0053, 0.0017, ..., -0.0194, 0.0158, -0.0051],
# [-0.0273, 0.0084, 0.0137, ..., 0.0173, 0.0176, -0.0172]]],
# grad_fn=<MulBackward0>)
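# Optional outputs (sketch, based on the wrapper implementation below):
# features = usm(wav, sr, output_hidden_states=True)
# features.hidden_states  # tuple: subsample-conv output plus one entry per Conformer block
# features.mask           # bool [batch, frames]; True marks padded frames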
# Apache-2.0 License
# (c) 2018- The Hugging Face team
# (c) 2025 Ishotihadus
import math
from dataclasses import dataclass
import torch
from torch import nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
from torchaudio.sox_effects import apply_effects_tensor
from torchaudio.transforms import MelScale
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gemma3n import (
    Gemma3nConfig,
    Gemma3nAudioConfig,
    Gemma3nPreTrainedModel,
    Gemma3nAudioFeatureExtractor,
)
from transformers.models.gemma3n.modeling_gemma3n import Gemma3nAudioSubSampleConvProjection, Gemma3nAudioConformerBlock
from transformers.utils import ModelOutput


@dataclass
class USMModelOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # True marks padded (masked) frames at the subsampled encoder frame rate
    mask: torch.BoolTensor | None = None


class USMModelInternal(PreTrainedModel):
    config_class = Gemma3nAudioConfig
    main_input_name = "input_features"

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__(config)
        self.config = config
        self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
        self.conformer = nn.ModuleList(
            [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
        )

    def forward(self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor,
                output_hidden_states: bool, output_attentions: bool) -> USMModelOutput:
        hidden_states = [] if output_hidden_states else None
        attentions = [] if output_attentions else None
        audio_encodings = self.subsample_conv_projection(audio_mel)
        if hidden_states is not None:
            hidden_states.append(audio_encodings)
        # Subsample the mel-frame mask to the encoder frame rate: output frame t
        # corresponds to mel frame t times the product of the SSCP time strides.
        t_sub = audio_encodings.shape[1]
        time_stride_product = 1
        for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
            time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
        indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
        indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1)
        if audio_mel_mask.ndim > 1 and indices.ndim == 1:
            indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1)
        elif audio_mel_mask.ndim == indices.ndim and audio_mel_mask.shape[0] == 1 and \
                indices.shape[0] != 1 and t_sub == indices.shape[0]:
            indices = indices.unsqueeze(0)
        current_mask = torch.gather(audio_mel_mask, 1, indices)
        for block in self.conformer:
            audio_encodings = block(audio_encodings, current_mask)
            if hidden_states is not None:
                hidden_states.append(audio_encodings)
            if attentions is not None:
                attentions.append(None)
        if hidden_states is not None:
            hidden_states = tuple(hidden_states)
        if attentions is not None:
            attentions = tuple(attentions)
        return USMModelOutput(audio_encodings, hidden_states, attentions, current_mask)
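
# Worked example of the mask subsampling above (a sketch, assuming the stride
# pairs are [[2, 2], [2, 2]] as expected for Gemma 3n): time_stride_product is
# 2 * 2 = 4, so the mask for subsampled frame t is taken from mel frame 4 * t,
# clamped to the last available mel frame.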


class USMModel(Gemma3nPreTrainedModel):
    _checkpoint_conversion_mapping = {}
    accepts_loss_kwargs = False
    base_model_prefix = "model"

    def __init__(self, config: Gemma3nConfig):
        super().__init__(config.audio_config)
        self.audio_tower = USMModelInternal._from_config(config.audio_config)
        self.post_init()

    def forward(self, input_features: torch.Tensor, input_features_mask: torch.BoolTensor | None = None,
                output_hidden_states: bool = False, output_attentions: bool = False) -> USMModelOutput:
        audio_mel = input_features
        # The audio tower expects True at padded positions, while
        # input_features_mask is True at valid positions, hence the inversion.
        if input_features_mask is None:
            audio_mel_mask = torch.zeros(input_features.shape[0:2], dtype=torch.bool, device=input_features.device)
        else:
            audio_mel_mask = ~input_features_mask
        return self.audio_tower(audio_mel, audio_mel_mask, output_hidden_states, output_attentions)
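
# A minimal sketch of driving USMModel directly with the stock Hugging Face
# feature extractor instead of the USMWrapper below. The output keys
# "input_features" / "input_features_mask" and the call signature are assumed
# from the transformers processor conventions, and `speech` stands for a 1-D
# float array at 16 kHz:
#
#   fe = Gemma3nAudioFeatureExtractor.from_pretrained("google/gemma-3n-E4B")
#   inputs = fe([speech], return_tensors="pt")
#   model = USMModel.from_pretrained("google/gemma-3n-E4B")
#   out = model(inputs["input_features"], inputs["input_features_mask"].bool())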


class USMWrapper(nn.Module):
    def __init__(self, model: str = "google/gemma-3n-E4B"):
        super().__init__()
        feature_extractor = Gemma3nAudioFeatureExtractor.from_pretrained(model)
        self.frame_length = feature_extractor.frame_length
        self.hop_length = feature_extractor.hop_length
        self.fft_length = feature_extractor.fft_length
        # Register the analysis window as a frozen parameter so it follows .to()/.cuda()
        self.window = nn.Parameter(torch.from_numpy(feature_extractor.window), requires_grad=False)
        self.mel_floor = feature_extractor.mel_floor.item()
        self.per_bin_mean = feature_extractor.per_bin_mean
        self.per_bin_stddev = feature_extractor.per_bin_stddev
        self.sample_rate = feature_extractor.sampling_rate
        self.mel_scale = MelScale(
            n_mels=feature_extractor.feature_size,
            sample_rate=self.sample_rate,
            f_min=feature_extractor.min_frequency,
            f_max=feature_extractor.max_frequency,
            n_stft=self.fft_length // 2 + 1,
        )
        self.preemphasis = feature_extractor.preemphasis
        self.preemphasis_htk_flavor = feature_extractor.preemphasis_htk_flavor
        self.encoder = USMModel.from_pretrained(model, attn_implementation="eager")
        self.config = self.encoder.config
        self.hidden_size = self.config.hidden_size
        # Samples of audio per encoder frame: mel hop length times the product of
        # the SSCP time strides (index 0 of each stride pair, as in USMModelInternal)
        encoder_hop_size = math.prod([e[0] for e in self.config.sscp_conv_stride_size])
        self.total_hop_size = feature_extractor.hop_length * encoder_hop_size

    def forward(self, x: list[torch.FloatTensor] | torch.FloatTensor, sr: int | None = None, lengths=None,
                output_hidden_states: bool = False, output_attentions: bool = False) -> USMModelOutput:
        # x may be a single waveform [n_samples], a padded batch [n_batches, n_samples],
        # or a list of variable-length waveforms; audio is resampled with sox when sr
        # differs from the model's sampling rate.
        if type(x) is list or type(x) is tuple:
            if sr is not None and sr != self.sample_rate:
                x_resampled = []
                for xx in x:
                    device = xx.device
                    xx, _ = apply_effects_tensor(xx.cpu().unsqueeze(0), sr, [["rate", "-vsL", str(self.sample_rate)]])
                    x_resampled.append(xx.to(device).mean(dim=0))
                x = x_resampled
                if lengths is not None:
                    lengths = lengths * self.sample_rate // sr
            x, auto_lengths = pad_packed_sequence(pack_sequence(x, enforce_sorted=False), batch_first=True)
            if lengths is None:
                lengths = auto_lengths
        elif x.ndim == 2:
            if sr is not None and sr != self.sample_rate:
                device = x.device
                x, _ = apply_effects_tensor(x.cpu(), sr, [["rate", "-vsL", str(self.sample_rate)]])
                x = x.to(device)
                if lengths is not None:
                    lengths = lengths * self.sample_rate // sr
            if lengths is None:
                lengths = torch.LongTensor([x.size(1) for _ in range(x.size(0))])
        else:
            x = x.unsqueeze(0)
            if sr is not None and sr != self.sample_rate:
                device = x.device
                x, _ = apply_effects_tensor(x.cpu(), sr, [["rate", "-vsL", str(self.sample_rate)]])
                x = x.to(device)
                if lengths is not None:
                    lengths = lengths * self.sample_rate // sr
            if lengths is None:
                lengths = torch.LongTensor([x.size(1)])
        lengths = lengths.to(x.device) // self.hop_length
        # Frame the signal; one extra sample per frame is kept so pre-emphasis can
        # look one sample back, mirroring Gemma3nAudioFeatureExtractor.
        frames = x.unfold(-1, self.frame_length + 1, self.hop_length)
        if self.preemphasis > 0:
            if self.preemphasis_htk_flavor:
                f1 = frames[..., :1] * (1.0 - self.preemphasis)
                f2 = frames[..., 1:-1] - self.preemphasis * frames[..., :-2]
                frames = torch.cat((f1, f2), dim=-1)
            else:
                frames = frames[..., 1:] - self.preemphasis * frames[..., :-1]
        else:
            frames = frames[..., :-1]
        frames = self.window[None, None, :] * frames
        # Magnitude spectrum -> mel filterbank -> log, then optional per-bin normalization
        features = torch.fft.rfft(frames, n=self.fft_length).abs().transpose(-1, -2)
        features = self.mel_scale(features).clamp(min=self.mel_floor).log().transpose(-1, -2)
        if self.per_bin_mean is not None:
            features = features - self.per_bin_mean
        if self.per_bin_stddev is not None:
            features = features / self.per_bin_stddev
        # True for valid frames and False for padded ones, matching the
        # attention-mask convention of Wav2Vec2FeatureExtractor
        mask = torch.arange(features.shape[1], device=features.device).unsqueeze(0) < lengths.unsqueeze(1)
        return self.encoder(features, mask, output_hidden_states, output_attentions)
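
# Example (sketch): batched inference over variable-length clips, mean-pooling
# the valid frames using the returned mask (True marks padding there, so it is
# inverted first). `wav_a` / `wav_b` are placeholder 1-D waveforms at rate sr.
#
#   usm = USMWrapper("google/gemma-3n-E4B")
#   out = usm([wav_a, wav_b], sr)
#   valid = (~out.mask).unsqueeze(-1).float()                  # [batch, frames, 1]
#   pooled = (out.last_hidden_state * valid).sum(1) / valid.sum(1).clamp(min=1)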