import os
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch

# load_model, generate_batch and logger are provided by this wildcard import.
from fish_speech.models.text2semantic.inference import *

# --- Configuration ---
text: list[str] = [
    "Что? Он никогда не работал?",
    "Никогда бы о таком не подумал раньше",
]
prompt_text: Optional[list[str]] = [
    "Что он никогда вот не работал, но стоит ему где-то пристроиться, он тут же идёт в гору очень быстро"
] * 2
prompt_tokens: Optional[list[Path]] = [Path("prompt_ORD_1.npy")] * 2
num_samples: int = 1
max_new_tokens: int = 256
top_p: float = 0.7
repetition_penalty: float = 1.2
temperature: float = 0.7
checkpoint_path: Path = Path("checkpoints/fish-speech-1.5")
device: str = "cuda"
compile: bool = False
seed: int = 1
half: bool = False
output_dir: Path = Path("temp")
batch_size = 2
batched = batch_size > 1  # generated codes carry a leading batch dimension when batching

os.makedirs(output_dir, exist_ok=True)
precision = torch.half if half else torch.bfloat16

if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
    raise ValueError(
        f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
    )

# --- Load the text2semantic model ---
logger.info("Loading model ...")
t0 = time.time()
model, decode_one_token = load_model(
    checkpoint_path, device, precision, compile=compile
)
with torch.device(device):
    model.setup_caches(
        max_batch_size=batch_size,
        max_seq_len=model.config.max_seq_len,
        dtype=next(model.parameters()).dtype,
    )
if torch.cuda.is_available():
    torch.cuda.synchronize()

logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

# Load the reference (prompt) codes from disk, if any.
prompt_tokens_ = None
if prompt_tokens is not None:
    prompt_tokens_ = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]

torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# --- Batched generation ---
generator = generate_batch(
    model=model,
    device=device,
    decode_one_token=decode_one_token,
    text=text,
    num_samples=num_samples,
    max_new_tokens=max_new_tokens,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
    temperature=temperature,
    compile=compile,
    prompt_text=prompt_text,
    prompt_tokens=prompt_tokens_,
)

idx = 0
codes = []
all_codes = []
for response in generator:
    if response.action == "sample":
        codes.append(response.codes)
        logger.info(f"Sampled text: {response.text}")
    elif response.action == "next":
        if codes:
            codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
            np.save(codes_npy_path, torch.cat(codes, dim=2 if batched else 1).cpu().numpy())
            logger.info(f"Saved codes to {codes_npy_path}")
            all_codes.extend(codes)
        logger.info("Next sample")
        codes = []
        idx += 1
    else:
        logger.error(f"Error: {response}")

# Also save the concatenated codes of the first batch element.
np.save(os.path.join(output_dir, "batched_gen"), torch.cat(all_codes, dim=2).cpu().numpy()[0])
Thanks for your work!!! Great help! I pulled your branch and ran this code, but only two tokens are output.
Could you help me fix this problem, please?