A GPU-friendly wrapper of USM in Google Gemma 3n
# sudo apt install libsox-dev
# mkdir -p ~/.cache/huggingface && echo 'hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx' > ~/.cache/huggingface/token
# uv run --with "transformers,torch,torchaudio<2.9" python
import torchaudio
from google_usm import USMWrapper

usm = USMWrapper("google/gemma-3n-E4B").cuda()  # keep the wrapper on the same device as the input
wav, sr = torchaudio.load("VOICEACTRESS100_001.wav")
wav = wav.cuda()  # GPU input is supported
features = usm(wav, sr)
# Input must be one of:
# - torch.FloatTensor[n_samples]
# - torch.FloatTensor[n_batches, n_samples]
# - list[torch.FloatTensor[n_samples]]
# features = usm([wav.squeeze(0)], sr)
# features = usm(wav.squeeze(0), sr)
features.last_hidden_state
# => tensor([[[ 0.0038, -0.0037, -0.0049,  ..., -0.0050, -0.0004, -0.0032],
#             [-0.0076,  0.0020,  0.0171,  ..., -0.0107, -0.0007,  0.0070],
#             [-0.0218, -0.0018,  0.0051,  ..., -0.0096, -0.0008,  0.0052],
#             ...,
#             [ 0.0037,  0.0019,  0.0043,  ...,  0.0164,  0.0194, -0.0064],
#             [-0.0139,  0.0053,  0.0017,  ..., -0.0194,  0.0158, -0.0051],
#             [-0.0273,  0.0084,  0.0137,  ...,  0.0173,  0.0176, -0.0172]]],
#            grad_fn=<MulBackward0>)
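
# Added sketch (not in the original gist): estimating the output length.
# total_hop_size is the feature-extractor hop length times the time stride of the
# encoder's subsampling convolutions, counted in samples at usm.sample_rate, so
# after resampling the encoder emits roughly one frame per total_hop_size samples.
approx_frames = int(wav.shape[-1] * usm.sample_rate / sr) // usm.total_hop_size
approx_frames, features.last_hidden_state.shape[1]  # should be close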

google_usm.py
# Apache-2.0 License
# (c) 2018- The Hugging Face team
# (c) 2025 Ishotihadus
import math
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
from torchaudio.sox_effects import apply_effects_tensor
from torchaudio.transforms import MelScale
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gemma3n import Gemma3nConfig, Gemma3nAudioConfig, Gemma3nPreTrainedModel, \
    Gemma3nAudioFeatureExtractor
from transformers.models.gemma3n.modeling_gemma3n import Gemma3nAudioSubSampleConvProjection, Gemma3nAudioConformerBlock
from transformers.utils import ModelOutput


@dataclass
class USMModelOutput(ModelOutput):
    # Encoder features, optional intermediates, and the downsampled padding mask (True = padded).
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    mask: torch.BoolTensor | None = None


class USMModelInternal(PreTrainedModel):
    # The audio tower alone: subsampling convolution followed by the conformer stack.
    config_class = Gemma3nAudioConfig
    main_input_name = "input_features"

    def __init__(self, config: Gemma3nAudioConfig):
        super().__init__(config)
        self.config = config
        self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
        self.conformer = nn.ModuleList(
            [Gemma3nAudioConformerBlock(config) for _ in range(config.conf_num_hidden_layers)]
        )

    def forward(self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor,
                output_hidden_states: bool, output_attentions: bool) -> USMModelOutput:
        hidden_states = [] if output_hidden_states else None
        attentions = [] if output_attentions else None
        audio_encodings = self.subsample_conv_projection(audio_mel)
        if hidden_states is not None:
            hidden_states.append(audio_encodings)
        # Downsample the frame-level padding mask to the subsampled time axis by
        # picking every time_stride_product-th input frame.
        t_sub = audio_encodings.shape[1]
        time_stride_product = 1
        for stride_pair_idx in range(len(self.config.sscp_conv_stride_size)):
            time_stride_product *= self.config.sscp_conv_stride_size[stride_pair_idx][0]
        indices = torch.arange(t_sub, device=audio_mel_mask.device) * time_stride_product
        indices = torch.clamp(indices, max=audio_mel_mask.shape[1] - 1)
        if audio_mel_mask.ndim > 1 and indices.ndim == 1:
            indices = indices.unsqueeze(0).expand(audio_mel_mask.shape[0], -1)
        elif audio_mel_mask.ndim == indices.ndim and audio_mel_mask.shape[0] == 1 and \
                indices.shape[0] != 1 and t_sub == indices.shape[0]:
            indices = indices.unsqueeze(0)
        current_mask = torch.gather(audio_mel_mask, 1, indices)
        for block in self.conformer:
            audio_encodings = block(audio_encodings, current_mask)
            if hidden_states is not None:
                hidden_states.append(audio_encodings)
            if attentions is not None:
                # The conformer blocks do not expose attention weights; keep placeholders.
                attentions.append(None)
        if hidden_states is not None:
            hidden_states = tuple(hidden_states)
        if attentions is not None:
            attentions = tuple(attentions)
        return USMModelOutput(audio_encodings, hidden_states, attentions, current_mask)
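
# Illustration (added; not part of the original file): assuming the common Gemma 3n
# audio config of two subsampling convolutions with a time stride of 2 each,
# time_stride_product == 4, so subsampled frame t reads the padding flag of mel
# frame 4 * t (clamped to the last valid index).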


class USMModel(Gemma3nPreTrainedModel):
    # Thin Gemma3nPreTrainedModel wrapper so the audio tower weights can be loaded
    # straight from a full Gemma 3n checkpoint.
    _checkpoint_conversion_mapping = {}
    accepts_loss_kwargs = False
    base_model_prefix = "model"

    def __init__(self, config: Gemma3nConfig):
        super().__init__(config.audio_config)
        self.audio_tower = USMModelInternal._from_config(config.audio_config)
        self.post_init()

    def forward(self, input_features: torch.Tensor, input_features_mask: torch.BoolTensor | None = None,
                output_hidden_states: bool = False, output_attentions: bool = False) -> USMModelOutput:
        audio_mel = input_features
        # input_features_mask is True for valid frames; the audio tower expects the
        # opposite convention (True = padded), hence the inversion below.
        if input_features_mask is None:
            audio_mel_mask = torch.zeros(input_features.shape[0:2], dtype=torch.bool, device=input_features.device)
        else:
            audio_mel_mask = ~input_features_mask
        return self.audio_tower(audio_mel, audio_mel_mask, output_hidden_states, output_attentions)


class USMWrapper(nn.Module):
    # End-to-end wrapper: waveform -> (resample) -> log-mel features -> USM encoder.
    def __init__(self, model: str = 'google/gemma-3n-E4B'):
        super().__init__()
        feature_extractor = Gemma3nAudioFeatureExtractor.from_pretrained(model)
        self.frame_length = feature_extractor.frame_length
        self.hop_length = feature_extractor.hop_length
        self.fft_length = feature_extractor.fft_length
        self.window = nn.Parameter(torch.from_numpy(feature_extractor.window), requires_grad=False)
        self.mel_floor = feature_extractor.mel_floor.item()
        self.per_bin_mean = feature_extractor.per_bin_mean
        self.per_bin_stddev = feature_extractor.per_bin_stddev
        self.sample_rate = feature_extractor.sampling_rate
        self.mel_scale = MelScale(
            n_mels=feature_extractor.feature_size,
            sample_rate=self.sample_rate,
            f_min=feature_extractor.min_frequency,
            f_max=feature_extractor.max_frequency,
            n_stft=self.fft_length // 2 + 1
        )
        self.preemphasis = feature_extractor.preemphasis
        self.preemphasis_htk_flavor = feature_extractor.preemphasis_htk_flavor
        self.encoder = USMModel.from_pretrained(model, attn_implementation="eager")
        self.config = self.encoder.config
        self.hidden_size = self.config.hidden_size
        # Time-axis stride of the subsampling convolutions (stride pairs are (time, frequency)).
        encoder_hop_size = math.prod([e[0] for e in self.config.sscp_conv_stride_size])
        self.total_hop_size = feature_extractor.hop_length * encoder_hop_size

    def forward(self, x: list[torch.FloatTensor] | torch.FloatTensor, sr: int = None, lengths=None,
                output_hidden_states: bool = False, output_attentions: bool = False) -> USMModelOutput:
        # Accept a single waveform, a padded batch, or a list of variable-length waveforms,
        # resampling with sox when sr differs from the model's sampling rate.
        if type(x) is list or type(x) is tuple:
            if sr is not None and sr != self.sample_rate:
                x_resampled = []
                for xx in x:
                    device = xx.device
                    xx, _ = apply_effects_tensor(xx.cpu().unsqueeze(0), sr, [["rate", "-vsL", str(self.sample_rate)]])
                    x_resampled.append(xx.to(device).mean(dim=0))
                x = x_resampled
                if lengths is not None:
                    lengths = lengths * self.sample_rate // sr
            x, auto_lengths = pad_packed_sequence(pack_sequence(x, enforce_sorted=False), batch_first=True)
            if lengths is None:
                lengths = auto_lengths
        elif x.ndim == 2:
            if sr is not None and sr != self.sample_rate:
                device = x.device
                x, _ = apply_effects_tensor(x.cpu(), sr, [["rate", "-vsL", str(self.sample_rate)]])
                x = x.to(device)
                if lengths is not None:
                    lengths = lengths * self.sample_rate // sr
            if lengths is None:
                lengths = torch.LongTensor([x.size(1) for _ in range(x.size(0))])
        else:
            x = x.unsqueeze(0)
            if sr is not None and sr != self.sample_rate:
                device = x.device
                x, _ = apply_effects_tensor(x.cpu(), sr, [["rate", "-vsL", str(self.sample_rate)]])
                x = x.to(device)
                if lengths is not None:
                    lengths = lengths * self.sample_rate // sr
            if lengths is None:
                lengths = torch.LongTensor([x.size(1)])
        # Convert sample counts to (approximate) mel-frame counts.
        lengths = lengths.to(x.device) // self.hop_length
        # Framing and pre-emphasis, mirroring Gemma3nAudioFeatureExtractor.
        frames = x.unfold(-1, self.frame_length + 1, self.hop_length)
        if self.preemphasis > 0:
            if self.preemphasis_htk_flavor:
                f1 = frames[..., :1] * (1.0 - self.preemphasis)
                f2 = frames[..., 1:-1] - self.preemphasis * frames[..., :-2]
                frames = torch.cat((f1, f2), dim=-1)
            else:
                frames = frames[..., 1:] - self.preemphasis * frames[..., :-1]
        else:
            frames = frames[..., :-1]
        frames = self.window[None, None, :] * frames
        # Magnitude spectrogram -> mel scale -> log, then per-bin normalization.
        features = torch.fft.rfft(frames, n=self.fft_length).abs().transpose(-1, -2)
        features = self.mel_scale(features).clamp(min=self.mel_floor).log().transpose(-1, -2)
        if self.per_bin_mean is not None:
            features = features - self.per_bin_mean
        if self.per_bin_stddev is not None:
            features = features / self.per_bin_stddev
        # False if masked, to match the convention of Wav2Vec2FeatureExtractor
        mask = torch.arange(features.shape[1], device=features.device).unsqueeze(0) < lengths.unsqueeze(1)
        return self.encoder(features, mask, output_hidden_states, output_attentions)
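
A minimal sketch (not part of the gist) of batched, variable-length use; random noise stands in for real audio, and no resampling happens because sr matches the model's sampling rate:

import torch
from google_usm import USMWrapper

usm = USMWrapper("google/gemma-3n-E4B")
a = torch.randn(usm.sample_rate)      # ~1 s of noise at the model's sampling rate
b = torch.randn(usm.sample_rate * 2)  # ~2 s
out = usm([a, b], sr=usm.sample_rate, output_hidden_states=True)
out.last_hidden_state.shape           # (2, n_frames, hidden_size), padded to the longer clip
out.mask.shape                        # (2, n_frames); True marks padded frames
len(out.hidden_states)                # conv projection output plus one entry per conformer block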