Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save eustlb/bcc532b53161bc31da3d66cb07ae193f to your computer and use it in GitHub Desktop.

Select an option

Save eustlb/bcc532b53161bc31da3d66cb07ae193f to your computer and use it in GitHub Desktop.
Reproducer for CSM Transformers integration
# TEST GREEDY FLOAT 32
# make sure to clone git@github.com:eustlb/csm.git and checkout compare-trfms
import sys
sys.path.insert(0, "./csm")
from generator import load_csm_1b, Segment
from datasets import load_dataset, Audio
from huggingface_hub import hf_hub_download
import torch
import torchaudio
# ---------------------------------------------------------------------------
# Fetch the reference CSM-1B checkpoint (pinned to an exact revision for
# reproducibility) and load it on the second GPU.
# ---------------------------------------------------------------------------
model_path = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="ckpt.pt",
    revision="03ab46ff5cfdcc783cc76fcf9ea6fd0838503093",
)
generator = load_csm_1b(model_path, "cuda:1")

# Load the dummy DailyTalk split and resample its audio column so context
# segments match the generator's native sample rate.
ds = load_dataset("eustlb/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=generator.sample_rate))
# Token ids used by the Transformers implementation for padded positions.
TEXT_PADDING_TOKEN = 128002
CODEBOOK_PADDING_TOKEN = 2050
CODEBOOK_EOS_TOKEN = 0
# Each frame has 33 channels: the first 32 are audio codebooks, the last is
# text — TODO confirm against the original CSM tokenizer layout.
NUM_CODEBOOKS = 32


def build_transformers_tensors(input_tokens, input_tokens_mask, output_tokens):
    """Convert the original model's (tokens, mask, output) triple into the
    tensors expected by the Transformers implementation.

    Masked audio positions are filled with the codebook padding token and
    masked text positions with the text padding token; an all-EOS frame is
    appended to the expected output so the generated sequence ends with an
    explicit end-of-audio marker.

    Returns:
        (input_tokens, expected_output_tokens) as long tensors on the same
        device as the inputs.
    """
    audio_mask = input_tokens_mask[..., :NUM_CODEBOOKS]
    text_mask = input_tokens_mask[..., NUM_CODEBOOKS:]
    # masked_fill is out-of-place and returns a new tensor, so no .clone()
    # of the input slices is needed (the original script cloned redundantly).
    audio_tokens = input_tokens[..., :NUM_CODEBOOKS].masked_fill(
        ~audio_mask, CODEBOOK_PADDING_TOKEN
    )
    text_tokens = input_tokens[..., NUM_CODEBOOKS:].masked_fill(
        ~text_mask, TEXT_PADDING_TOKEN
    )
    merged_inputs = torch.cat([audio_tokens, text_tokens], dim=-1)

    # One (1, 1, NUM_CODEBOOKS) frame of EOS tokens appended along the
    # sequence dimension.
    eos_frame = torch.full(
        (1, 1, NUM_CODEBOOKS),
        CODEBOOK_EOS_TOKEN,
        device=input_tokens.device,
        dtype=torch.long,
    )
    expected_outputs = torch.cat([output_tokens, eos_frame], dim=1)
    # NOTE(review): the original script also built an
    # `expected_output_tokens_mask` by concatenating the input mask with a
    # ones column along dim=-1; it was never used (and the dim looked
    # inconsistent with the sequence-dim concat above), so it is dropped.
    return merged_inputs, expected_outputs


# ============== context: first dataset row as a single Segment ==============
context_audio = torch.from_numpy(ds[0]["audio"]["array"]).to(torch.float32)
segments = [
    Segment(text=ds[0]["text"], speaker=ds[0]["speaker_id"], audio=context_audio)
]

# ============== batch index 0: generation with one context segment ==========
input_tokens, input_tokens_mask, output_tokens = generator.generate_return_tokens(
    text=ds[1]["text"],
    speaker=ds[1]["speaker_id"],
    context=segments,
    max_audio_length_ms=10_000,
)
inputs_0, expected_0 = build_transformers_tensors(
    input_tokens, input_tokens_mask, output_tokens
)
torch.save(inputs_0, "input_tokens_0.pt")
torch.save(expected_0, "expected_output_tokens_0.pt")

# ============== batch index 1: generation with empty context ================
input_tokens, input_tokens_mask, output_tokens = generator.generate_return_tokens(
    text=ds[0]["text"],
    speaker=ds[0]["speaker_id"],
    context=[],
    max_audio_length_ms=10_000,
)
inputs_1, expected_1 = build_transformers_tensors(
    input_tokens, input_tokens_mask, output_tokens
)
torch.save(inputs_1, "input_tokens_1.pt")
torch.save(expected_1, "expected_output_tokens_1.pt")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment