Training LLM-Based + Neural Codec TTS Models
# %%
# !pip install -q trl
# %%
import os
os.environ["WANDB_PROJECT"] = "gemma3-snac-finetuning"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import logging
import wandb
import torch
import numpy as np
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    TrainerCallback
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
# Colab-only block: log in to W&B and Hugging Face using Colab secrets.
# Comment out (or remove) this block if you are not running on Colab.
# Start
from google.colab import userdata
from huggingface_hub.hf_api import HfFolder
# setup wandb token
wandb_token = userdata.get('wandb_token')
wandb.login(key=wandb_token)
# setup hf token
HfFolder.save_token(userdata.get('hf_token'))
# End
# %%
@dataclass
class ModelArguments:
    """Arguments pertaining to model configuration."""
    model_name_or_path: str = field(
        default="google/gemma-3-270m",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."}
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={"help": "Override the default `torch.dtype` and load the model under this dtype."}
    )
@dataclass
class DataArguments:
    """Arguments pertaining to data configuration."""
    dataset_name: str = field(
        default="mohitmayank/elise_text_snac_codes",
        metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=1100,
        metadata={"help": "The maximum total input sequence length after tokenization."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."}
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: int = field(
        default=10,
        metadata={"help": "Percentage of the dataset to use for validation"}
    )
@dataclass
class TrainingArgumentsCustom(TrainingArguments):
    # output_dir: str = field(default="./gemma3-snac-finetuning")
    output_dir: str = field(default="output_dir")
    overwrite_output_dir: bool = field(default=False)
    do_train: bool = field(default=True)
    do_eval: bool = field(default=False)
    # Logging & Verbosity
    logging_steps: int = field(default=10)            # log every 10 steps
    logging_first_step: bool = field(default=True)    # log the very first step
    disable_tqdm: bool = field(default=False)         # keep tqdm progress bars
    report_to: str = field(default="wandb")           # also log to W&B
    # run_name: Optional[str] = field(default="verbose_run")
    # Evaluation / Saving
    # eval_strategy: str = field(default="steps")
    # eval_steps: int = field(default=5)              # more frequent eval
    # Training params
    per_device_train_batch_size: int = field(default=2)
    per_device_eval_batch_size: int = field(default=1)
    gradient_accumulation_steps: int = field(default=4)
    learning_rate: float = field(default=2e-4)
    weight_decay: float = field(default=0.01)
    num_train_epochs: float = field(default=20)
    lr_scheduler_type: str = field(default="cosine")
    warmup_ratio: float = field(default=0.1)
    save_total_limit: int = field(default=5)
    save_strategy: str = field(default="steps")
    save_steps: int = field(default=10)
    # load_best_model_at_end: bool = field(default=True)
    # metric_for_best_model: str = field(default="eval_loss")
    greater_is_better: bool = field(default=False)
    # Dataloader
    dataloader_num_workers: int = field(default=0)
    dataloader_pin_memory: bool = field(default=False)
    # Mixed precision
    fp16: bool = field(default=False)
    bf16: bool = field(default=False)
    # Other
    remove_unused_columns: bool = field(default=False)
    include_inputs_for_metrics: bool = field(default=True)
    prediction_loss_only: bool = field(default=False)


# Initialize configurations
print("Initializing configurations...")
model_args = ModelArguments()
data_args = DataArguments()
training_args = TrainingArgumentsCustom()
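# Quick check: with per-device batch size 2 and gradient accumulation 4,
# each optimizer step effectively sees 2 * 4 = 8 samples per device.
effective_batch = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
print(f"Effective train batch size per device: {effective_batch}")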
# %%
# SNAC codec configuration
SNAC_CONFIG = {
    "num_layers": 3,
    "codes_per_layer": 4096,        # Each layer has codes 0-4095
    "total_snac_tokens": 4096 * 3,  # 12,288 additional tokens
    "layer_names": ["snac_l1", "snac_l2", "snac_l3"],
    "special_tokens": {
        "audio_start": "<audio_start>",
        "audio_end": "<audio_end>",
        "layer_sep": "<layer_sep>",
        "pad_token": "<snac_pad>"
    }
}
print(f"SNAC Config: {SNAC_CONFIG}")
# %%
def load_snac_dataset(data_args: DataArguments) -> Tuple[Dataset, Dataset]:
    """Load the SNAC dataset and register its special tokens.

    Note: no train/validation split is done here; the full set is reused
    for both train and eval in this experiment.
    """
    print(f"Loading dataset: {data_args.dataset_name}")
    # Load the full dataset
    dataset = load_dataset(data_args.dataset_name, split="train")
    print(f"Loaded {len(dataset)} samples")
    # Add any <...> markers found in the text to SNAC_CONFIG's special tokens
    special_words = []
    for i in range(len(dataset)):
        text = dataset[i]['text']
        special_words.extend([word for word in text.split() if word.startswith('<') and word.endswith('>')])
    # Keep only the unique special words
    unique_special_words = list(set(special_words))
    SNAC_CONFIG["special_tokens"].update({word: word for word in unique_special_words})
    train_dataset = dataset
    eval_dataset = dataset
    print(f"Train samples: {len(train_dataset)}")
    print(f"Eval samples: {len(eval_dataset)}")
    return train_dataset, eval_dataset
def convert_snac_codes_to_tokens(snac_codes: List[List[List[int]]], tokenizer: AutoTokenizer) -> List[str]:
    """Convert SNAC codes to token strings, layer by layer with <layer_sep> between layers."""
    tokens = [SNAC_CONFIG["special_tokens"]["audio_start"]]
    # Process each layer
    for layer_idx, layer_codes in enumerate(snac_codes):
        if layer_idx > 0:
            tokens.append(SNAC_CONFIG["special_tokens"]["layer_sep"])
        layer_name = SNAC_CONFIG["layer_names"][layer_idx]
        # Flatten the layer codes and convert to tokens
        for batch_codes in layer_codes:
            for code_val in batch_codes:
                token = f"<{layer_name}_{code_val}>"
                tokens.append(token)
    tokens.append(SNAC_CONFIG["special_tokens"]["audio_end"])
    return tokens
def convert_snac_codes_to_tokens_v2(snac_codes: List[List[List[int]]], tokenizer: AutoTokenizer) -> List[str]:
    """Convert SNAC codes to token strings, interleaved frame by frame (1 + 2 + 4 codes per frame)."""
    tokens = [SNAC_CONFIG["special_tokens"]["audio_start"]]
    # Iterate over the coarse (layer 1) frames
    for i in range(len(snac_codes[0][0])):
        # add the token from layer 1 (1 token per frame)
        tokens.append(f"<{SNAC_CONFIG['layer_names'][0]}_{snac_codes[0][0][i]}>")
        # add the tokens from layer 2 (2 tokens per frame)
        tokens.extend([f"<{SNAC_CONFIG['layer_names'][1]}_{code_val}>" for code_val in snac_codes[1][0][(i*2):(i*2)+2]])
        # add the tokens from layer 3 (4 tokens per frame)
        tokens.extend([f"<{SNAC_CONFIG['layer_names'][2]}_{code_val}>" for code_val in snac_codes[2][0][(i*4):(i*4)+4]])
    tokens.append(SNAC_CONFIG["special_tokens"]["audio_end"])
    return tokens
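# Toy example: two frames of dummy codes in the dataset's [layer][batch][time]
# layout (1 code per frame in layer 1, 2 in layer 2, 4 in layer 3), so the
# interleaved output is <audio_start> + 2 * 7 code tokens + <audio_end> = 16 tokens.
_dummy_codes = [
    [[10, 11]],
    [[20, 21, 22, 23]],
    [[30, 31, 32, 33, 34, 35, 36, 37]],
]
print(len(convert_snac_codes_to_tokens_v2(_dummy_codes, tokenizer=None)))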
def preprocess_function(examples: Dict[str, List], tokenizer: AutoTokenizer, max_length: int):
    """Preprocess examples for training."""
    batch_size = len(examples["text"])
    processed_texts = []
    for i in range(batch_size):
        text = examples["text"][i]
        snac_codes = examples["snac_codes"][i]
        # Convert SNAC codes to tokens
        snac_tokens = convert_snac_codes_to_tokens_v2(snac_codes, tokenizer)
        snac_text = "".join(snac_tokens)
        # Create input in format: "{text} <audio_start>...<audio_end>"
        input_text = f"{text} {snac_text}"
        processed_texts.append(input_text)
    # Tokenize the processed texts
    tokenized = tokenizer(
        processed_texts,
        truncation=True,
        padding='max_length',  # Pad sequences to max_length
        max_length=max_length,
        return_tensors=None
    )
    # Set labels equal to input_ids for causal language modeling
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized
def prepare_datasets(train_dataset: Dataset, eval_dataset: Dataset, tokenizer: AutoTokenizer, data_args: DataArguments):
    """Prepare datasets for training."""
    print("Preprocessing datasets...")
    # Preprocess training dataset
    train_dataset = train_dataset.map(
        lambda examples: preprocess_function(examples, tokenizer, data_args.max_seq_length),
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=train_dataset.column_names,
        load_from_cache_file=not data_args.overwrite_cache,
        desc="Preprocessing train dataset"
    )
    # Preprocess evaluation dataset
    # eval_dataset = eval_dataset.map(
    #     lambda examples: preprocess_function(examples, tokenizer, data_args.max_seq_length),
    #     batched=True,
    #     num_proc=data_args.preprocessing_num_workers,
    #     remove_columns=eval_dataset.column_names,
    #     load_from_cache_file=not data_args.overwrite_cache,
    #     desc="Preprocessing eval dataset"
    # )
    # For testing, keep eval = train
    eval_dataset = train_dataset
    print(f"Preprocessed train dataset: {len(train_dataset)} samples")
    print(f"Preprocessed eval dataset: {len(eval_dataset)} samples")
    return train_dataset, eval_dataset
# %%
def setup_model_and_tokenizer(model_args: ModelArguments) -> Tuple[AutoModelForCausalLM, AutoTokenizer, int]:
    """Setup and extend the Gemma model and tokenizer for SNAC tokens."""
    print(f"Loading model and tokenizer from {model_args.model_name_or_path}")
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        trust_remote_code=True
    )
    # Add padding token if not present
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Pad on the right for causal LM training
    tokenizer.padding_side = "right"
    # Store original vocab size
    original_vocab_size = len(tokenizer)
    print(f"Original vocabulary size: {original_vocab_size}")
    # Add SNAC special tokens
    special_tokens_list = list(SNAC_CONFIG["special_tokens"].values())
    num_added_special = tokenizer.add_special_tokens({
        "additional_special_tokens": special_tokens_list
    })
    print(f"Added {num_added_special} special tokens")
    # Add SNAC code tokens
    snac_tokens = []
    for layer_idx in range(SNAC_CONFIG["num_layers"]):
        layer_name = SNAC_CONFIG["layer_names"][layer_idx]
        for code_val in range(SNAC_CONFIG["codes_per_layer"]):
            token = f"<{layer_name}_{code_val}>"
            snac_tokens.append(token)
    # Add all SNAC tokens to the tokenizer
    num_added_snac = tokenizer.add_tokens(snac_tokens)
    print(f"Added {num_added_snac} SNAC code tokens")
    # Verify total tokens added
    total_added = num_added_special + num_added_snac
    expected_total = len(special_tokens_list) + SNAC_CONFIG["total_snac_tokens"]
    assert total_added == expected_total, f"Expected {expected_total} tokens, got {total_added}"
    print(f"New vocabulary size: {len(tokenizer)}")
    # Load model (defaults to bfloat16 unless a dtype override is given via model_args)
    torch_dtype = getattr(torch, model_args.torch_dtype) if model_args.torch_dtype else torch.bfloat16
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation='eager'
    )
    # Resize model embeddings to accommodate new tokens
    print(f"Resizing model embeddings to {len(tokenizer)} tokens")
    model.resize_token_embeddings(len(tokenizer))
    print(f"Resized model embeddings to {len(tokenizer)} tokens")
    print("Initialized new token embeddings")
    return model, tokenizer, original_vocab_size
# %% [markdown]
# ## Load Dataset and Model
# %%
print("Starting Gemma 3 SNAC finetuning with standard HF Trainer...")
# Setup W&B logging
# setup_wandb(training_args, model_args, data_args)
# Load datasets
train_dataset, eval_dataset = load_snac_dataset(data_args)
# Setup model and tokenizer
model, tokenizer, original_vocab_size = setup_model_and_tokenizer(model_args)
print("Model Dtype: ", model.dtype)
# Prepare datasets
train_dataset, eval_dataset = prepare_datasets(train_dataset, eval_dataset, tokenizer, data_args)
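# Optional check: decode the start of one preprocessed sample to confirm the
# "{text} <audio_start>..." layout survived tokenization.
_sample_ids = train_dataset[0]["input_ids"]
print(tokenizer.decode(_sample_ids[:40]))
print("Sequence length:", len(_sample_ids))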
# Data collator for language modeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # We're doing causal LM, not masked LM
    pad_to_multiple_of=8 if training_args.fp16 or training_args.bf16 else None,
    return_tensors="pt"
)
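# %%
# Optional check: collate two samples and confirm the shapes, and that the
# collator rebuilds labels from input_ids with pad positions masked to -100.
_batch = data_collator([train_dataset[0], train_dataset[1]])
print({k: tuple(v.shape) for k, v in _batch.items()})
print("Masked (ignored) label positions:", int((_batch["labels"] == -100).sum()))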
# %%
from trl import SFTTrainer
# Initialize the SFTTrainer
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=data_collator,
)
print("SFTTrainer initialized.")
# %%
# Start training
print("Starting model training...")
trainer.train()
print("Training finished.")