Training LLM-Based + Neural Codec TTS Models
# %%
# !pip install -q trl

# %%
import os
os.environ["WANDB_PROJECT"] = "gemma3-snac-finetuning"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import logging
import wandb
import torch
import numpy as np
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    TrainerCallback
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version

# Colab-only block: loads the W&B and Hugging Face tokens from Colab secrets.
# Comment this block out if you are not running on Colab.
# Start
from google.colab import userdata
from huggingface_hub.hf_api import HfFolder

# setup wandb token
wandb_token = userdata.get('wandb_token')
wandb.login(key=wandb_token)

# setup hf token
HfFolder.save_token(userdata.get('hf_token'))
# End
# %%
@dataclass
class ModelArguments:
    """Arguments pertaining to model configuration."""
    model_name_or_path: str = field(
        default="google/gemma-3-270m",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."}
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={"help": "Override the default `torch.dtype` and load the model under this dtype."}
    )


@dataclass
class DataArguments:
    """Arguments pertaining to data configuration."""
    dataset_name: str = field(
        default="mohitmayank/elise_text_snac_codes",
        metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=1100,
        metadata={"help": "The maximum total input sequence length after tokenization."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."}
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: int = field(
        default=10,
        metadata={"help": "Percentage of the dataset to use for validation"}
    )


@dataclass
class TrainingArgumentsCustom(TrainingArguments):
    # output_dir: str = field(default="./gemma3-snac-finetuning")
    output_dir: str = field(default="output_dir")
    overwrite_output_dir: bool = field(default=False)
    do_train: bool = field(default=True)
    do_eval: bool = field(default=False)

    # Logging & Verbosity
    logging_steps: int = field(default=10)           # log every 10 steps
    logging_first_step: bool = field(default=True)   # log the very first step
    disable_tqdm: bool = field(default=False)        # keep tqdm progress bars
    report_to: str = field(default="wandb")          # also log to W&B
    # run_name: Optional[str] = field(default="verbose_run")

    # Evaluation / Saving
    # eval_strategy: str = field(default="steps")
    # eval_steps: int = field(default=5)             # more frequent eval

    # Training params
    per_device_train_batch_size: int = field(default=2)
    per_device_eval_batch_size: int = field(default=1)
    gradient_accumulation_steps: int = field(default=4)
    learning_rate: float = field(default=2e-4)
    weight_decay: float = field(default=0.01)
    num_train_epochs: float = field(default=20)
    lr_scheduler_type: str = field(default="cosine")
    warmup_ratio: float = field(default=0.1)
    save_total_limit: int = field(default=5)
    save_strategy: str = field(default="steps")
    save_steps: int = field(default=10)
    # load_best_model_at_end: bool = field(default=True)
    # metric_for_best_model: str = field(default="eval_loss")
    greater_is_better: bool = field(default=False)

    # Dataloader
    dataloader_num_workers: int = field(default=0)
    dataloader_pin_memory: bool = field(default=False)

    # Mixed precision
    fp16: bool = field(default=False)
    bf16: bool = field(default=False)

    # Other
    remove_unused_columns: bool = field(default=False)
    include_inputs_for_metrics: bool = field(default=True)
    prediction_loss_only: bool = field(default=False)


# Initialize configurations
print("Initializing configurations...")
model_args = ModelArguments()
data_args = DataArguments()
training_args = TrainingArgumentsCustom()
# %%
# SNAC codec configuration
SNAC_CONFIG = {
    "num_layers": 3,
    "codes_per_layer": 4096,        # Each layer has codes 0-4095
    "total_snac_tokens": 4096 * 3,  # 12,288 additional tokens
    "layer_names": ["snac_l1", "snac_l2", "snac_l3"],
    "special_tokens": {
        "audio_start": "<audio_start>",
        "audio_end": "<audio_end>",
        "layer_sep": "<layer_sep>",
        "pad_token": "<snac_pad>"
    }
}
print(f"SNAC Config: {SNAC_CONFIG}")
# %%
def load_snac_dataset(data_args: DataArguments) -> Tuple[Dataset, Dataset]:
    """Load the SNAC dataset (no train/eval split is applied here; both returned splits reuse the full set)."""
    print(f"Loading dataset: {data_args.dataset_name}")

    # Load the full dataset
    dataset = load_dataset(data_args.dataset_name, split="train")
    print(f"Loaded {len(dataset)} samples")

    # Collect any special tokens already present in the text and add them to SNAC_CONFIG
    special_words = []
    for i in range(len(dataset)):
        text = dataset[i]['text']
        special_words.extend([word for word in text.split() if word.startswith('<') and word.endswith('>')])

    # find the unique special words
    unique_special_words = list(set(special_words))
    SNAC_CONFIG["special_tokens"].update({word: word for word in unique_special_words})

    train_dataset = dataset
    eval_dataset = dataset
    print(f"Train samples: {len(train_dataset)}")
    print(f"Eval samples: {len(eval_dataset)}")
    return train_dataset, eval_dataset


def convert_snac_codes_to_tokens(snac_codes: List[List[List[int]]], tokenizer: AutoTokenizer) -> List[str]:
    """Convert SNAC codes to token strings, layer by layer (layers separated by <layer_sep>)."""
    tokens = [SNAC_CONFIG["special_tokens"]["audio_start"]]

    # Process each layer
    for layer_idx, layer_codes in enumerate(snac_codes):
        if layer_idx > 0:
            tokens.append(SNAC_CONFIG["special_tokens"]["layer_sep"])
        layer_name = SNAC_CONFIG["layer_names"][layer_idx]

        # Flatten the layer codes and convert to tokens
        for batch_codes in layer_codes:
            for code_val in batch_codes:
                token = f"<{layer_name}_{code_val}>"
                tokens.append(token)

    tokens.append(SNAC_CONFIG["special_tokens"]["audio_end"])
    return tokens


def convert_snac_codes_to_tokens_v2(snac_codes: List[List[List[int]]], tokenizer: AutoTokenizer) -> List[str]:
    """Convert SNAC codes to token strings, interleaved frame by frame (1 L1 + 2 L2 + 4 L3 codes per frame)."""
    tokens = [SNAC_CONFIG["special_tokens"]["audio_start"]]

    # Iterate over the coarse-layer frames
    for i in range(len(snac_codes[0][0])):
        # add the token from layer 1 (1 token)
        tokens.append(f"<{SNAC_CONFIG['layer_names'][0]}_{snac_codes[0][0][i]}>")
        # add the tokens from layer 2 (2 tokens)
        tokens.extend([f"<{SNAC_CONFIG['layer_names'][1]}_{code_val}>" for code_val in snac_codes[1][0][(i*2):(i*2)+2]])
        # add the tokens from layer 3 (4 tokens)
        tokens.extend([f"<{SNAC_CONFIG['layer_names'][2]}_{code_val}>" for code_val in snac_codes[2][0][(i*4):(i*4)+4]])

    tokens.append(SNAC_CONFIG["special_tokens"]["audio_end"])
    return tokens
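# %%
# Quick illustration of the interleaving produced by convert_snac_codes_to_tokens_v2,
# using tiny made-up codes (the values below are dummies, not real SNAC output).
# For 2 coarse frames: layer 1 has 2 codes, layer 2 has 4, layer 3 has 8, so the
# flattened sequence is 2 * (1 + 2 + 4) = 14 code tokens plus the <audio_start>/<audio_end> markers.
_dummy_codes = [
    [[10, 11]],                          # layer 1: 1 code per frame
    [[20, 21, 22, 23]],                  # layer 2: 2 codes per frame
    [[30, 31, 32, 33, 34, 35, 36, 37]],  # layer 3: 4 codes per frame
]
print(convert_snac_codes_to_tokens_v2(_dummy_codes, tokenizer=None))  # tokenizer is unused by the function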
def preprocess_function(examples: Dict[str, List], tokenizer: AutoTokenizer, max_length: int):
    """Preprocess examples for training."""
    batch_size = len(examples["text"])
    processed_texts = []

    for i in range(batch_size):
        text = examples["text"][i]
        snac_codes = examples["snac_codes"][i]

        # Convert SNAC codes to tokens
        snac_tokens = convert_snac_codes_to_tokens_v2(snac_codes, tokenizer)
        snac_text = "".join(snac_tokens)

        # Create input in the format: "{text} {snac_tokens}"
        input_text = f"{text} {snac_text}"
        processed_texts.append(input_text)

    # Tokenize the processed texts
    tokenized = tokenizer(
        processed_texts,
        truncation=True,
        padding='max_length',  # Pad sequences to max_length
        max_length=max_length,
        return_tensors=None
    )

    # Set labels equal to input_ids for causal language modeling
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized


def prepare_datasets(train_dataset: Dataset, eval_dataset: Dataset, tokenizer: AutoTokenizer, data_args: DataArguments):
    """Prepare datasets for training."""
    print("Preprocessing datasets...")

    # Preprocess training dataset
    train_dataset = train_dataset.map(
        lambda examples: preprocess_function(examples, tokenizer, data_args.max_seq_length),
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=train_dataset.column_names,
        load_from_cache_file=not data_args.overwrite_cache,
        desc="Preprocessing train dataset"
    )

    # Preprocess evaluation dataset
    # eval_dataset = eval_dataset.map(
    #     lambda examples: preprocess_function(examples, tokenizer, data_args.max_seq_length),
    #     batched=True,
    #     num_proc=data_args.preprocessing_num_workers,
    #     remove_columns=eval_dataset.column_names,
    #     load_from_cache_file=not data_args.overwrite_cache,
    #     desc="Preprocessing eval dataset"
    # )
    # for testing, keep eval = train
    eval_dataset = train_dataset

    print(f"Preprocessed train dataset: {len(train_dataset)} samples")
    print(f"Preprocessed eval dataset: {len(eval_dataset)} samples")
    return train_dataset, eval_dataset
# %%
def setup_model_and_tokenizer(model_args: ModelArguments) -> Tuple[AutoModelForCausalLM, AutoTokenizer, int]:
    """Setup and extend the Gemma model and tokenizer for SNAC tokens."""
    print(f"Loading model and tokenizer from {model_args.model_name_or_path}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        trust_remote_code=True
    )

    # Add padding token if not present
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Pad on the right for training (left padding is only needed at generation time)
    tokenizer.padding_side = "right"

    # Store original vocab size
    original_vocab_size = len(tokenizer)
    print(f"Original vocabulary size: {original_vocab_size}")

    # Add SNAC special tokens
    special_tokens_list = list(SNAC_CONFIG["special_tokens"].values())
    num_added_special = tokenizer.add_special_tokens({
        "additional_special_tokens": special_tokens_list
    })
    print(f"Added {num_added_special} special tokens")

    # Add SNAC code tokens
    snac_tokens = []
    for layer_idx in range(SNAC_CONFIG["num_layers"]):
        layer_name = SNAC_CONFIG["layer_names"][layer_idx]
        for code_val in range(SNAC_CONFIG["codes_per_layer"]):
            token = f"<{layer_name}_{code_val}>"
            snac_tokens.append(token)

    # Add all SNAC tokens to the tokenizer
    num_added_snac = tokenizer.add_tokens(snac_tokens)
    print(f"Added {num_added_snac} SNAC code tokens")

    # Verify total tokens added
    total_added = num_added_special + num_added_snac
    expected_total = len(special_tokens_list) + SNAC_CONFIG["total_snac_tokens"]
    assert total_added == expected_total, f"Expected {expected_total} tokens, got {total_added}"
    print(f"New vocabulary size: {len(tokenizer)}")

    # Load model (default to bfloat16 unless a dtype override is provided)
    torch_dtype = getattr(torch, model_args.torch_dtype) if model_args.torch_dtype else torch.bfloat16
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation='eager'
    )

    # Resize model embeddings to accommodate new tokens
    print(f"Resizing model embeddings to {len(tokenizer)} tokens")
    model.resize_token_embeddings(len(tokenizer))
    print(f"Resized model embeddings to {len(tokenizer)} tokens")
    print("Initialized new token embeddings")

    return model, tokenizer, original_vocab_size
# %% [markdown]
# ## Load Dataset and Model

# %%
print("Starting Gemma 3 SNAC finetuning with standard HF Trainer...")

# Setup W&B logging
# setup_wandb(training_args, model_args, data_args)

# Load datasets
train_dataset, eval_dataset = load_snac_dataset(data_args)

# Setup model and tokenizer
model, tokenizer, original_vocab_size = setup_model_and_tokenizer(model_args)
print("Model Dtype: ", model.dtype)
# Prepare datasets
train_dataset, eval_dataset = prepare_datasets(train_dataset, eval_dataset, tokenizer, data_args)

# Data collator for language modeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # We're doing causal LM, not masked LM
    pad_to_multiple_of=8 if training_args.fp16 or training_args.bf16 else None,
    return_tensors="pt"
)
# %%
from trl import SFTTrainer

# Initialize the SFTTrainer
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    data_collator=data_collator,
)
print("SFTTrainer initialized.")

# %%
# Start training
print("Starting model training...")
trainer.train()
print("Training finished.")