@mizoru
Last active July 15, 2025 15:34
import os
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from loguru import logger

from fish_speech.models.text2semantic.inference import *
text: list[str] = ["Что? Он никогда не работал?", "Никогда бы о таком не подумал раньше"]
prompt_text: Optional[list[str]] = ["Что он никогда вот не работал, но стоит ему где-то пристроиться, он тут же идёт в гору очень быстро"] * 2
prompt_tokens: Optional[list[Path]] = [Path("prompt_ORD_1.npy")] * 2
num_samples: int = 1
max_new_tokens: int = 256
top_p: float = 0.7
repetition_penalty: float = 1.2
temperature: float = 0.7
checkpoint_path: Path = Path("checkpoints/fish-speech-1.5")
device: str = "cuda"
compile: bool = False
seed: int = 1
half: bool = False
output_dir: Path = Path("temp")
batch_size = 2
batched = batch_size > 1  # codes from generate_batch carry a leading batch dimension when batching
os.makedirs(output_dir, exist_ok=True)
precision = torch.half if half else torch.bfloat16
if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
    raise ValueError(
        f"Number of prompt texts ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
    )
logger.info("Loading model ...")
t0 = time.time()
model, decode_one_token = load_model(
    checkpoint_path, device, precision, compile=compile
)
with torch.device(device):
    model.setup_caches(
        max_batch_size=batch_size,
        max_seq_len=model.config.max_seq_len,
        dtype=next(model.parameters()).dtype,
    )
if torch.cuda.is_available():
    torch.cuda.synchronize()
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
# Load the reference prompt codes from disk
if prompt_tokens is not None:
    prompt_tokens_ = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
generator = generate_batch(
    model=model,
    device=device,
    decode_one_token=decode_one_token,
    text=text,
    num_samples=num_samples,
    max_new_tokens=max_new_tokens,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
    temperature=temperature,
    compile=compile,
    prompt_text=prompt_text,
    prompt_tokens=prompt_tokens_,
)
idx = 0
codes = []   # codes for the current sample
codess = []  # codes accumulated across all samples
for response in generator:
    if response.action == "sample":
        codes.append(response.codes)
        logger.info(f"Sampled text: {response.text}")
    elif response.action == "next":
        if codes:
            codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
            np.save(codes_npy_path, torch.cat(codes, dim=2 if batched else 1).cpu().numpy())
            logger.info(f"Saved codes to {codes_npy_path}")
            codess.extend(codes)
        logger.info("Next sample")
        codes = []
        idx += 1
    else:
        logger.error(f"Error: {response}")
np.save("temp/batched_gen", torch.cat(codess, dim=2).cpu().numpy()[0])
@YangGuangzhaoJJJacky

Thanks for your work!!! Great help! I pulled your branch and ran this code, but only two tokens were output:

2025-04-18 16:24:15.587 | INFO     | __main__:<module>:27 - Loading model ...
2025-04-18 16:24:19.194 | INFO     | fish_speech.models.text2semantic.inference:load_model:816 - Restored model from checkpoint
2025-04-18 16:24:19.194 | INFO     | fish_speech.models.text2semantic.inference:load_model:822 - Using DualARTransformer
2025-04-18 16:24:19.203 | INFO     | __main__:<module>:41 - Time to load model: 3.62 seconds
2025-04-18 16:24:19.223 | INFO     | fish_speech.models.text2semantic.inference:generate_batch:1094 - Encoded text: こちらは。 こちらは。
2025-04-18 16:24:19.249 | INFO     | fish_speech.models.text2semantic.inference:generate_batch:1094 - Encoded text: エーアイですが、人間のように会話できます。
2025-04-18 16:24:19.250 | INFO     | fish_speech.models.text2semantic.inference:generate_batch:1103 - Encoded text batch of shape [2, 9, 9], and prompt batch [2, 9, 1742]
2025-04-18 16:24:19.250 | INFO     | fish_speech.models.text2semantic.inference:generate_batch:1125 - Generating sentence 1/1 of sample 1/1
  0%|                                                                                                                       | 0/255 [00:00<?, ?it/s]/home/recosele/miniconda3/envs/fish-speech/lib/python3.12/contextlib.py:105: FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
  self.gen = func(*args, **kwds)
  0%|                                                                                                                       | 0/255 [00:00<?, ?it/s]
2025-04-18 16:24:19.660 | INFO     | fish_speech.models.text2semantic.inference:generate_batch:1180 - Generated 2 tokens in 0.41 seconds, 4.87 tokens/sec
2025-04-18 16:24:19.660 | INFO     | fish_speech.models.text2semantic.inference:generate_batch:1183 - Bandwidth achieved: 3.11 GB/s
2025-04-18 16:24:19.661 | INFO     | fish_speech.models.text2semantic.inference:generate_batch:1188 - GPU Memory used: 2.06 GB
2025-04-18 16:24:19.661 | INFO     | __main__:<module>:74 - Sampled text: こちらは。 こちらは。
2025-04-18 16:24:19.661 | INFO     | __main__:<module>:79 - Saved codes to temp/codes_0.npy
2025-04-18 16:24:19.661 | INFO     | __main__:<module>:81 - Next sample

Could you help me fix this problem, please?
