Last active: April 23, 2025 09:38
-
-
Save eustlb/aed822f765e928b9612e01b0d8836d69 to your computer and use it in GitHub Desktop.
Reproducer for CSM Transformers integration
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TEST GREEDY FLOAT 32
#
# Reproducer for the CSM Transformers integration: runs greedy generation with
# the original CSM model, converts its token frames into the padding scheme the
# Transformers implementation expects, and saves the tensors for comparison.
# Make sure to clone git@github.com:eustlb/csm.git and check out the
# `compare-trfms` branch first.
import sys

sys.path.insert(0, "./csm")

from generator import load_csm_1b, Segment
from huggingface_hub import hf_hub_download
import torch
import torchaudio  # NOTE(review): unused here — presumably kept for interactive use; confirm

# Each frame is laid out as NUM_CODEBOOKS audio-codebook slots followed by text slots.
NUM_CODEBOOKS = 32

# Load the original (reference) model from a pinned checkpoint revision.
model_path = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="ckpt.pt",
    revision="03ab46ff5cfdcc783cc76fcf9ea6fd0838503093",
)
generator = load_csm_1b(model_path, "cuda:1")

# Infer with the original model, keeping the raw token frames and their mask.
input_tokens, input_tokens_mask, output_tokens = generator.generate_return_tokens(
    text="The past is just a story we tell ourselves.",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
)

# Special token ids used by the Transformers implementation.
text_padding_token = 128002
codebook_padding_token = 2050
codebook_eos_token = 0

# Replace masked positions with the corresponding padding ids expected by the
# Transformers implementation (audio slots and text slots pad differently).
audio_frames_mask = input_tokens_mask[..., :NUM_CODEBOOKS]
text_frames_mask = input_tokens_mask[..., NUM_CODEBOOKS:]
input_tokens_audio = input_tokens[..., :NUM_CODEBOOKS].clone()
input_tokens_text = input_tokens[..., NUM_CODEBOOKS:].clone()
input_tokens_audio = input_tokens_audio.masked_fill(~audio_frames_mask, codebook_padding_token)
input_tokens_text = input_tokens_text.masked_fill(~text_frames_mask, text_padding_token)
input_tokens = torch.cat([input_tokens_audio, input_tokens_text], dim=-1)

# Append an all-EOS audio frame to the expected output, mirroring how the
# Transformers implementation terminates generation.
eos_frame = torch.full(
    (1, 1, NUM_CODEBOOKS),
    codebook_eos_token,
    device=input_tokens.device,
    dtype=torch.long,
)
expected_output_tokens = torch.cat([output_tokens, eos_frame], dim=1)
# NOTE(review): computed but never saved below — confirm whether the comparison
# script needs it; if so, torch.save it alongside the other tensors.
expected_output_tokens_mask = torch.cat(
    [input_tokens_mask, torch.ones_like(input_tokens_mask[..., :1])], dim=-1
)

torch.save(input_tokens, "input_tokens.pt")
torch.save(expected_output_tokens, "expected_output_tokens.pt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.