Created
April 13, 2026 13:39
-
-
Save duarteocarmo/7fb6dfa9e623dfbcbb8b97a9e867f904 to your computer and use it in GitHub Desktop.
Qwen3-TTS finetuning sft_12hz.py with 0.6B + 1.7B support; otherwise identical to upstream
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # coding=utf-8 | |
| # Copyright 2026 The Alibaba Qwen team. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import json | |
| import os | |
| import shutil | |
| import torch | |
| from accelerate import Accelerator | |
| from dataset import TTSDataset | |
| from qwen_tts.inference.qwen3_tts_model import Qwen3TTSModel | |
| from safetensors.torch import save_file | |
| from torch.optim import AdamW | |
| from torch.utils.data import DataLoader | |
| from transformers import AutoConfig | |
# Speaker embedding captured by ``train()`` from the first training batch
# (speaker-encoder output for that batch's reference mels). Its row 0 is
# written into the codec embedding table of every saved checkpoint.
# ``None`` until the first batch has been processed.
target_speaker_embedding = None
# Codec-embedding row that receives the fine-tuned speaker embedding; the same
# id is registered for ``--speaker_name`` in the saved config's ``spk_id``.
_SPEAKER_CODEC_ID = 3000
# Sequence position whose codec embedding is overwritten with the speaker
# embedding. NOTE(review): presumably TTSDataset reserves this slot - confirm.
_SPEAKER_EMBED_POS = 6
# Number of codebook columns of ``codec_ids`` that are used: column 0 is the
# talker target (``codec_0_labels``); columns 1.._NUM_CODE_GROUPS-1 are embedded
# with the code predictor's input embeddings and added to the talker input.
_NUM_CODE_GROUPS = 16
# Weight of the sub-talker loss relative to the talker loss.
_SUB_TALKER_LOSS_WEIGHT = 0.3


def _parse_args():
    """Parse and validate the command-line options of the fine-tuning script.

    Exits through ``argparse`` (status 2) when ``--init_model_path`` is not a
    local directory. Checkpoints are built by copying that directory, so a Hub
    repo id would otherwise fail only after the first full training epoch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--init_model_path",
        type=str,
        default="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
        help="Local directory of the base model (it is copied into every checkpoint).",
    )
    parser.add_argument("--output_model_path", type=str, default="output")
    parser.add_argument("--train_jsonl", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--speaker_name", type=str, default="speaker_test")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--attn_implementation", type=str, default="flash_attention_2")
    args = parser.parse_args()
    if not os.path.isdir(args.init_model_path):
        parser.error(
            f"--init_model_path must be a local model directory, got {args.init_model_path!r}; "
            "download it first, e.g. `huggingface-cli download <repo_id> --local-dir <dir>`."
        )
    return args


def _load_jsonl(path):
    """Return the JSON objects stored one per line in *path*.

    The file is decoded as UTF-8 regardless of the platform locale, and blank
    lines (such as a trailing empty line) are skipped.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _compute_loss(model, batch):
    """Run one training forward pass and return ``(loss, speaker_embedding)``.

    ``loss`` is the talker loss on codebook 0 plus the weighted sub-talker
    loss. ``speaker_embedding`` is the speaker-encoder output for the batch's
    reference mels, computed without gradients.
    """
    input_ids = batch["input_ids"]
    codec_ids = batch["codec_ids"]
    ref_mels = batch["ref_mels"]
    text_embedding_mask = batch["text_embedding_mask"]
    codec_embedding_mask = batch["codec_embedding_mask"]
    attention_mask = batch["attention_mask"]
    codec_0_labels = batch["codec_0_labels"]
    codec_mask = batch["codec_mask"]

    # NOTE(review): submodules/attributes are read directly from the prepared
    # model. This assumes it is not wrapped (a DDP wrapper would need
    # accelerator.unwrap_model) - confirm for multi-GPU runs.
    # The speaker-encoder output is not trained through (the original
    # detached it), so skip building its autograd graph.
    with torch.no_grad():
        speaker_embedding = model.speaker_encoder(
            ref_mels.to(device=model.device, dtype=model.dtype)
        )

    talker = model.talker
    input_text_ids = input_ids[:, :, 0]
    input_codec_ids = input_ids[:, :, 1]
    input_text_embedding = talker.model.text_embedding(input_text_ids)
    input_codec_embedding = talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask
    input_codec_embedding[:, _SPEAKER_EMBED_POS, :] = speaker_embedding
    # Project text embeddings only when their width differs from the codec
    # embedding width. NOTE(review): presumably only the 0.6B checkpoint needs
    # this - confirm text_projection is not expected unconditionally.
    if input_text_embedding.shape[-1] != input_codec_embedding.shape[-1]:
        input_text_embedding = talker.text_projection(input_text_embedding)
    input_text_embedding = input_text_embedding * text_embedding_mask
    input_embeddings = input_text_embedding + input_codec_embedding

    code_predictor_embeddings = talker.code_predictor.get_input_embeddings()
    for i in range(1, _NUM_CODE_GROUPS):
        codec_i_embedding = code_predictor_embeddings[i - 1](codec_ids[:, :, i])
        input_embeddings = input_embeddings + codec_i_embedding * codec_mask.unsqueeze(-1)

    # Inputs drop the last position and labels drop the first, so each
    # position is paired with the next step's codebook-0 label (assuming the
    # talker does not shift labels again internally - confirm).
    outputs = talker(
        inputs_embeds=input_embeddings[:, :-1, :],
        attention_mask=attention_mask[:, :-1],
        labels=codec_0_labels[:, 1:],
        output_hidden_states=True,
    )
    hidden_states = outputs.hidden_states[0][-1]
    talker_hidden_states = hidden_states[codec_mask[:, :-1]]
    talker_codec_ids = codec_ids[codec_mask]
    _, sub_talker_loss = talker.forward_sub_talker_finetune(talker_codec_ids, talker_hidden_states)
    loss = outputs.loss + _SUB_TALKER_LOSS_WEIGHT * sub_talker_loss
    return loss, speaker_embedding


def _save_checkpoint(accelerator, model, model_path, output_dir, speaker_name, speaker_embedding):
    """Write a ``custom_voice`` checkpoint for *speaker_name* into *output_dir*.

    Steps:
    - Copy the base model directory.
    - Rewrite its ``config.json`` so *speaker_name* maps to ``_SPEAKER_CODEC_ID``.
    - Save the trained weights (without the speaker encoder) to
      ``model.safetensors``, with ``speaker_embedding[0]`` stored in that row
      of the codec embedding table.
    """
    shutil.copytree(model_path, output_dir, dirs_exist_ok=True)

    with open(os.path.join(model_path, "config.json"), "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config_dict["tts_model_type"] = "custom_voice"
    talker_config = config_dict.get("talker_config", {})
    talker_config["spk_id"] = {speaker_name: _SPEAKER_CODEC_ID}
    talker_config["spk_is_dialect"] = {speaker_name: False}
    config_dict["talker_config"] = talker_config
    with open(os.path.join(output_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)

    unwrapped_model = accelerator.unwrap_model(model)
    # Speaker-encoder weights are excluded from the saved checkpoint.
    state_dict = {
        k: v.detach().to("cpu")
        for k, v in unwrapped_model.state_dict().items()
        if not k.startswith("speaker_encoder")
    }

    key = "talker.model.codec_embedding.weight"
    # Clone before writing. For a CPU-resident model, ``.to("cpu")`` returns
    # the parameter's own storage, so an in-place write would also change the
    # model that is still being trained.
    codec_weight = state_dict[key].clone()
    codec_weight[_SPEAKER_CODEC_ID] = speaker_embedding[0].detach().to(
        device=codec_weight.device, dtype=codec_weight.dtype
    )
    state_dict[key] = codec_weight

    save_file(state_dict, os.path.join(output_dir, "model.safetensors"))


def train():
    """Fine-tune a Qwen3-TTS base model on one speaker, checkpointing each epoch.

    Options come from the command line (see ``_parse_args``). After every
    epoch the main process writes ``checkpoint-epoch-{epoch}`` under
    ``--output_model_path``.

    Raises:
        ValueError: if ``--train_jsonl`` contains no samples.
    """
    global target_speaker_embedding
    # Reset so a repeated call in the same process does not reuse the
    # previous run's speaker embedding.
    target_speaker_embedding = None

    args = _parse_args()
    model_path = args.init_model_path

    # Load (and validate) the data before the expensive model load.
    train_data = _load_jsonl(args.train_jsonl)
    if not train_data:
        raise ValueError(f"No training samples found in {args.train_jsonl!r}")

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision="bf16",
        log_with="tensorboard",
        # Accelerate requires a logging directory for tensorboard; without one
        # the constructor raises whenever tensorboard is installed.
        project_dir=args.output_model_path,
    )

    qwen3tts = Qwen3TTSModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation=args.attn_implementation,
    )
    config = AutoConfig.from_pretrained(model_path)

    dataset = TTSDataset(train_data, qwen3tts.processor, config)
    train_dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True, collate_fn=dataset.collate_fn
    )
    optimizer = AdamW(qwen3tts.model.parameters(), lr=args.lr, weight_decay=0.01)
    model, optimizer, train_dataloader = accelerator.prepare(
        qwen3tts.model, optimizer, train_dataloader
    )

    model.train()
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                loss, speaker_embedding = _compute_loss(model, batch)
                # The first batch's embedding is kept; _save_checkpoint stores
                # its row 0 in every checkpoint.
                if target_speaker_embedding is None:
                    target_speaker_embedding = speaker_embedding
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            if step % 10 == 0:
                accelerator.print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")

        # Let all processes finish the epoch before the main process reads weights.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            output_dir = os.path.join(args.output_model_path, f"checkpoint-epoch-{epoch}")
            _save_checkpoint(
                accelerator,
                model,
                model_path,
                output_dir,
                args.speaker_name,
                target_speaker_embedding,
            )
# Script entry point: run fine-tuning with options parsed from the command line.
if __name__ == "__main__":
    train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment