Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save eustlb/aed822f765e928b9612e01b0d8836d69 to your computer and use it in GitHub Desktop.

Select an option

Save eustlb/aed822f765e928b9612e01b0d8836d69 to your computer and use it in GitHub Desktop.
Reproducer for CSM Transformers integration
# TEST GREEDY FLOAT 32
# Prerequisite: clone git@github.com:eustlb/csm.git and check out the
# `compare-trfms` branch so that ./csm is importable below.
import sys

sys.path.insert(0, "./csm")

import torch
import torchaudio
from huggingface_hub import hf_hub_download

from generator import load_csm_1b, Segment

# Load the original (reference) CSM checkpoint at a pinned revision.
model_path = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="ckpt.pt",
    revision="03ab46ff5cfdcc783cc76fcf9ea6fd0838503093",
)
generator = load_csm_1b(model_path, "cuda:1")

# Run greedy generation with the reference implementation, keeping the raw
# token tensors so they can be compared against the Transformers port.
input_tokens, input_tokens_mask, output_tokens = generator.generate_return_tokens(
    text="The past is just a story we tell ourselves.",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
)

# Special token ids used by the Transformers implementation.
text_padding_token = 128002
codebook_padding_token = 2050
codebook_eos_token = 0

# Replace masked positions with the padding ids the Transformers implem
# expects. Last-dim layout: entries [:32] are audio codebooks, [32:] is text.
mask_audio = input_tokens_mask[..., :32]
mask_text = input_tokens_mask[..., 32:]
tokens_audio = input_tokens[..., :32].clone().masked_fill(~mask_audio, codebook_padding_token)
tokens_text = input_tokens[..., 32:].clone().masked_fill(~mask_text, text_padding_token)
input_tokens = torch.cat([tokens_audio, tokens_text], dim=-1)

# Append an all-EOS frame (32 codebooks) to the expected output.
eos_frame = torch.full(
    (1, 1, 32), codebook_eos_token, device=input_tokens.device, dtype=torch.long
)
expected_output_tokens = torch.cat([output_tokens, eos_frame], dim=1)
expected_output_tokens_mask = torch.cat(
    [input_tokens_mask, torch.ones_like(input_tokens_mask[..., :1])], dim=-1
)

# Persist the tensors consumed by the Transformers-side comparison script.
torch.save(input_tokens, "input_tokens.pt")
torch.save(expected_output_tokens, "expected_output_tokens.pt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment