Skip to content

Instantly share code, notes, and snippets.

@eustlb
Last active May 19, 2025 14:49
Show Gist options
  • Select an option

  • Save eustlb/85888dd85f567dad83a4b0288a8eb720 to your computer and use it in GitHub Desktop.

Select an option

Save eustlb/85888dd85f567dad83a4b0288a8eb720 to your computer and use it in GitHub Desktop.
from datasets import load_dataset, Audio
from transformers import (
CsmForConditionalGeneration,
TrainingArguments,
CsmProcessor,
Trainer
)
# Load the CSM processor first: its feature extractor defines the target
# sampling rate used to resample the dataset audio below.
processor = CsmProcessor.from_pretrained("eustlb/csm-1b")

# Stream the grouped DailyTalk conversations and decode/resample audio lazily
# to the rate the feature extractor expects.
ds = load_dataset("eustlb/dailytalk-conversations-grouped", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

# Put the model in training mode, but keep the audio codec sub-module in eval
# mode. NOTE(review): presumably the codec is used as a frozen tokenizer for
# audio targets rather than trained — confirm against the CSM training recipe.
model = CsmForConditionalGeneration.from_pretrained("eustlb/csm-1b")
model.train()
model.codec_model.eval()
def data_collator(samples):
    """Collate raw dataset rows into a model-ready tokenized batch.

    Each sample carries one concatenated waveform (``sample["audio"]["array"]``)
    plus ``audio_cut_idxs`` (start, end) pairs delimiting the individual turns.
    The collator slices the waveform back into per-turn segments, rebuilds the
    chat-template conversation structure (one message per turn, with the
    speaker id as the role), and tokenizes everything with the module-level
    ``processor``.

    Args:
        samples: list of dataset rows with keys ``audio``, ``audio_cut_idxs``,
            ``speaker_ids`` and ``texts``.

    Returns:
        The dict returned by ``processor.apply_chat_template`` (tokenized
        inputs plus labels, since ``output_labels=True``).
    """
    conversations = []
    for sample in samples:
        concatenated_audio_array = sample["audio"]["array"]
        # Slice the concatenated waveform back into one segment per turn.
        # (Original code named this list `audio` and then shadowed it with the
        # loop variable of the zip below — renamed to avoid the shadowing.)
        audio_segments = [
            concatenated_audio_array[start:end]
            for start, end in sample["audio_cut_idxs"]
        ]
        conversation = []
        for speaker_id, text, segment in zip(
            sample["speaker_ids"], sample["texts"], audio_segments
        ):
            conversation.append({
                # Chat-template roles are the stringified speaker ids.
                "role": f"{speaker_id}",
                "content": [
                    {"type": "text", "text": text},
                    {"type": "audio", "audio": segment},
                ],
            })
        conversations.append(conversation)
    inputs = processor.apply_chat_template(
        conversations,
        tokenize=True,
        return_dict=True,
        output_labels=True,
    )
    return inputs
# Configure the HF Trainer run. Columns must survive to the collator
# (remove_unused_columns=False), and gradient checkpointing trades compute
# for memory on the backbone.
training_args = TrainingArguments(
    output_dir="csm-1b-finetuned",
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds,
    data_collator=data_collator,
)

# Kick off fine-tuning.
trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment