Fix for llama-server
import os
import re
from copy import deepcopy
from functools import cache
from typing import Any, AsyncIterator, Protocol, cast

from mistralai import Mistral
from openai import AsyncOpenAI, OpenAI

from unmute.kyutai_constants import KYUTAI_LLM_MODEL, LLM_SERVER

INTERRUPTION_CHAR = "—"  # em-dash
USER_SILENCE_MARKER = "..."


def preprocess_messages_for_llm(
    chat_history: list[dict[str, str]],
) -> list[dict[str, str]]:
    output = []
    for message in chat_history:
        message = deepcopy(message)

        # Sometimes an interruption happens before the LLM can say anything at
        # all, leaving a message containing only INTERRUPTION_CHAR. Simplify by
        # removing it.
        if message["content"].replace(INTERRUPTION_CHAR, "") == "":
            continue

        if output and message["role"] == output[-1]["role"]:
            output[-1]["content"] += " " + message["content"]
        else:
            output.append(message)

    def role_at(index: int) -> str | None:
        if index >= len(output):
            return None
        return output[index]["role"]

    if role_at(0) == "system" and role_at(1) in [None, "assistant"]:
        # Some LLMs, like Gemma, get confused if an assistant message comes
        # before any user message, so add a dummy user message.
        output = [output[0]] + [{"role": "user", "content": "Hello."}] + output[1:]

    # Strip markers from the copies in `output` (they are what gets returned);
    # mutating `chat_history` here would have no effect on the result.
    for message in output:
        if (
            message["role"] == "user"
            and message["content"].startswith(USER_SILENCE_MARKER)
            and message["content"] != USER_SILENCE_MARKER
        ):
            # This happens when the user is silent but then starts talking again
            # after the silence marker was inserted but before the LLM could
            # respond. The system prompt has special instructions about how to
            # handle the silence marker, so strip it from the message to not
            # confuse the LLM.
            message["content"] = message["content"][len(USER_SILENCE_MARKER) :]

    return output
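# Illustrative example (not part of the original gist) of what the
# preprocessing does: interruption-only messages are dropped, consecutive
# same-role messages are merged, and a dummy user turn is inserted when the
# assistant would otherwise speak right after the system prompt:
#
#   preprocess_messages_for_llm([
#       {"role": "system", "content": "Be brief."},
#       {"role": "assistant", "content": "—"},  # interruption only: dropped
#       {"role": "assistant", "content": "Hi"},
#       {"role": "assistant", "content": "there."},
#   ])
#   == [
#       {"role": "system", "content": "Be brief."},
#       {"role": "user", "content": "Hello."},
#       {"role": "assistant", "content": "Hi there."},
#   ]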
async def rechunk_to_words(iterator: AsyncIterator[str]) -> AsyncIterator[str]:
    """Rechunk the stream of text to whole words.

    Otherwise the TTS doesn't know where word boundaries are and will
    mispronounce split words.

    The spaces will be included with the next word, so "foo bar baz" will be
    split into "foo", " bar", " baz". Multiple space-like characters will be
    merged into a single space.
    """
    buffer = ""
    space_re = re.compile(r"\s+")
    prefix = ""

    async for delta in iterator:
        buffer = buffer + delta

        while True:
            match = space_re.search(buffer)
            if match is None:
                break

            chunk = buffer[: match.start()]
            buffer = buffer[match.end() :]
            if chunk != "":
                yield prefix + chunk
                prefix = " "

    if buffer != "":
        yield prefix + buffer
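# Illustrative example (not part of the original gist): whitespace attaches to
# the *following* word and runs of whitespace collapse to a single space, so
# iterating rechunk_to_words over the deltas ["Hel", "lo  wor", "ld!"] yields
# "Hello" and then " world!".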
class LLMStream(Protocol):
    async def chat_completion(
        self, messages: list[dict[str, str]]
    ) -> AsyncIterator[str]:
        """Get a chat completion from the LLM."""
        ...


class MistralStream:
    def __init__(self):
        self.current_message_index = 0
        self.mistral = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

    async def chat_completion(
        self, messages: list[dict[str, str]]
    ) -> AsyncIterator[str]:
        event_stream = await self.mistral.chat.stream_async(
            model="mistral-large-latest",
            messages=cast(Any, messages),  # It's too annoying to type this properly
            temperature=1.0,
        )

        async for event in event_stream:
            delta = event.data.choices[0].delta.content
            assert isinstance(delta, str)  # make Pyright happy
            yield delta


def get_openai_client(server_url: str = LLM_SERVER) -> AsyncOpenAI:
    return AsyncOpenAI(api_key="EMPTY", base_url=server_url + "/v1")


@cache
def autoselect_model() -> str:
    if KYUTAI_LLM_MODEL is not None:
        return KYUTAI_LLM_MODEL

    client_sync = OpenAI(api_key="EMPTY", base_url=get_openai_client().base_url)
    models = client_sync.models.list()
    if len(models.data) != 1:
        raise ValueError("There are multiple models available. Please specify one.")
    return models.data[0].id


class VLLMStream:
    def __init__(
        self,
        client: AsyncOpenAI,
        temperature: float = 1.0,
        model: str | None = None,  # Optional override for the autoselected model
    ):
        """
        If `model` is None, look at the available models; if there is exactly
        one, use it. Otherwise, raise.
        """
        self.client = client
        self.model = model if model is not None else autoselect_model()
        self.temperature = temperature

    async def chat_completion(
        self, messages: list[dict[str, str]]
    ) -> AsyncIterator[str]:
        stream = await self.client.chat.completions.create(
            model=self.model,
            messages=cast(Any, messages),  # Cast and hope for the best
            stream=True,
            temperature=self.temperature,
        )

        async with stream:
            async for chunk in stream:
                # Skip malformed or empty chunks without choices or a delta.
                if not chunk.choices or not chunk.choices[0].delta:
                    continue

                chunk_content = chunk.choices[0].delta.content
                # The llama-server fix: `delta.content` can be None (common for
                # the initial role-only chunk and the final chunk of an OpenAI-
                # style stream), so only yield when there is actual text instead
                # of asserting unconditionally.
                if chunk_content is not None:
                    assert isinstance(
                        chunk_content, str
                    ), f"Expected str but got {type(chunk_content)}: {chunk_content}"
                    yield chunk_content
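

# Minimal usage sketch (not part of the original gist): wires the helpers
# together end to end. Assumes LLM_SERVER points at a running llama-server or
# vLLM instance exposing the OpenAI-compatible /v1 API with one model loaded.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        history = [
            {"role": "system", "content": "You are a concise voice assistant."},
            {"role": "user", "content": "Tell me a short joke."},
        ]
        stream = VLLMStream(get_openai_client())
        # Preprocess the history, stream the completion, and rechunk it into
        # whole words as the TTS would consume them.
        async for word in rechunk_to_words(
            stream.chat_completion(preprocess_messages_for_llm(history))
        ):
            print(word, end="", flush=True)
        print()

    asyncio.run(_demo())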