Created
April 23, 2025 08:43
-
-
Save eustlb/0c94de002e1325abb61d32217f74c0f8 to your computer and use it in GitHub Desktop.
Reproducer for CSM Transformers integration
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TEST GREEDY FLOAT 32
# Reproducer for the CSM Transformers integration: runs greedy float32
# generation with the ORIGINAL sesame/csm-1b model and saves its padded
# input tokens and expected output tokens to .pt files, so the Transformers
# implementation can be diffed against them.
#
# Make sure to clone git@github.com:eustlb/csm.git and checkout the
# `compare-trfms` branch so `generator` below resolves to that repo.
import sys

sys.path.insert(0, "./csm")

import torch
import torchaudio  # noqa: F401 — kept from the original reproducer

from datasets import load_dataset, Audio
from generator import load_csm_1b, Segment
from huggingface_hub import hf_hub_download

# Load the original (non-Transformers) model at a pinned checkpoint revision
# so the comparison stays reproducible.
model_path = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="ckpt.pt",
    revision="03ab46ff5cfdcc783cc76fcf9ea6fd0838503093",
)
generator = load_csm_1b(model_path, "cuda:1")

# Prepare data: the first 4 dataset rows become conversation-context segments,
# with audio resampled to the model's native sample rate.
ds = load_dataset("eustlb/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=generator.sample_rate))

segments = []
for i in range(min(4, len(ds))):  # guard against datasets shorter than 4 rows
    row = ds[i]
    audio = torch.from_numpy(row["audio"]["array"]).to(torch.float32)
    segments.append(Segment(text=row["text"], speaker=row["speaker_id"], audio=audio))

# Infer with the original model; row 4 is the prompt to continue.
input_tokens, input_tokens_mask, output_tokens = generator.generate_return_tokens(
    text=ds[4]["text"],
    speaker=ds[4]["speaker_id"],
    context=segments,
    max_audio_length_ms=10_000,
)

# Padding / EOS ids used by the Transformers implementation.
text_padding_token = 128002
codebook_padding_token = 2050
codebook_eos_token = 0

# Replace masked positions with the padding values the Transformers implem
# expects. In the last dim, the first 32 entries are audio codebook tokens
# and the remainder are text tokens (split mirrors the original layout).
audio_frames_mask = input_tokens_mask[..., :32]
text_frames_mask = input_tokens_mask[..., 32:]
input_tokens_audio = input_tokens[..., :32].clone()
input_tokens_text = input_tokens[..., 32:].clone()
input_tokens_audio = input_tokens_audio.masked_fill(~audio_frames_mask, codebook_padding_token)
input_tokens_text = input_tokens_text.masked_fill(~text_frames_mask, text_padding_token)
input_tokens = torch.cat([input_tokens_audio, input_tokens_text], dim=-1)

# Append one all-EOS codebook frame (shape (1, 1, 32)) to the expected output,
# since the Transformers model emits it at end of generation.
eos_frame = torch.full(
    (1, 1, 32), codebook_eos_token, device=input_tokens.device, dtype=torch.long
)
expected_output_tokens = torch.cat([output_tokens, eos_frame], dim=1)
# NOTE(review): this mask is computed but never saved below — presumably kept
# for interactive inspection; confirm before removing.
expected_output_tokens_mask = torch.cat(
    [input_tokens_mask, torch.ones_like(input_tokens_mask[..., :1])], dim=-1
)

torch.save(input_tokens, "input_tokens.pt")
torch.save(expected_output_tokens, "expected_output_tokens.pt")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment