@jepjoo
Created July 4, 2025 12:34

Fix for llama-server
import os
import re
from copy import deepcopy
from functools import cache
from typing import Any, AsyncIterator, Protocol, cast

from mistralai import Mistral
from openai import AsyncOpenAI, OpenAI

from unmute.kyutai_constants import KYUTAI_LLM_MODEL, LLM_SERVER

INTERRUPTION_CHAR = "—"  # em-dash
USER_SILENCE_MARKER = "..."

def preprocess_messages_for_llm(
    chat_history: list[dict[str, str]],
) -> list[dict[str, str]]:
    output = []

    for message in chat_history:
        message = deepcopy(message)

        # Sometimes, an interruption happens before the LLM can say anything at all.
        # In that case, we're left with a message with only INTERRUPTION_CHAR.
        # Simplify by removing it.
        if message["content"].replace(INTERRUPTION_CHAR, "") == "":
            continue

        # Merge consecutive messages from the same role into one.
        if output and message["role"] == output[-1]["role"]:
            output[-1]["content"] += " " + message["content"]
        else:
            output.append(message)

    def role_at(index: int) -> str | None:
        if index >= len(output):
            return None
        return output[index]["role"]

    if role_at(0) == "system" and role_at(1) in [None, "assistant"]:
        # Some LLMs, like Gemma, get confused if the assistant message goes before user
        # messages, so add a dummy user message.
        output = [output[0]] + [{"role": "user", "content": "Hello."}] + output[1:]

    # Note: iterate over `output` (the deep-copied messages we return), not over
    # `chat_history`, so the stripped content is what actually reaches the LLM.
    for message in output:
        if (
            message["role"] == "user"
            and message["content"].startswith(USER_SILENCE_MARKER)
            and message["content"] != USER_SILENCE_MARKER
        ):
            # This happens when the user is silent but then starts talking again after
            # the silence marker was inserted but before the LLM could respond.
            # There are special instructions in the system prompt about how to handle
            # the silence marker, so remove the marker from the message to not confuse
            # the LLM.
            message["content"] = message["content"][len(USER_SILENCE_MARKER) :]

    return output
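

# Hypothetical sanity check (not part of the original gist) illustrating the merging
# and dummy-user-turn behavior of preprocess_messages_for_llm above:
#
#   preprocess_messages_for_llm([
#       {"role": "system", "content": "You are a voice assistant."},
#       {"role": "assistant", "content": "Hi"},
#       {"role": "assistant", "content": "there."},
#   ])
#   # -> [{"role": "system", "content": "You are a voice assistant."},
#   #     {"role": "user", "content": "Hello."},
#   #     {"role": "assistant", "content": "Hi there."}]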


async def rechunk_to_words(iterator: AsyncIterator[str]) -> AsyncIterator[str]:
    """Rechunk the stream of text to whole words.

    Otherwise the TTS doesn't know where word boundaries are and will mispronounce
    split words.

    The spaces will be included with the next word, so "foo bar baz" will be split into
    "foo", " bar", " baz".
    Multiple space-like characters will be merged into a single space.
    """
    buffer = ""
    space_re = re.compile(r"\s+")
    prefix = ""

    async for delta in iterator:
        buffer = buffer + delta

        while True:
            match = space_re.search(buffer)
            if match is None:
                break

            chunk = buffer[: match.start()]
            buffer = buffer[match.end() :]
            if chunk != "":
                yield prefix + chunk
                prefix = " "

    if buffer != "":
        yield prefix + buffer
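

# Hypothetical demo (not part of the original gist) of the word-rechunking behavior
# described in the docstring above:
#
#   import asyncio
#
#   async def _demo():
#       async def deltas():
#           for d in ["foo b", "ar   ", "baz"]:
#               yield d
#
#       print([word async for word in rechunk_to_words(deltas())])
#       # -> ['foo', ' bar', ' baz']
#
#   asyncio.run(_demo())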


class LLMStream(Protocol):
    async def chat_completion(
        self, messages: list[dict[str, str]]
    ) -> AsyncIterator[str]:
        """Get a chat completion from the LLM."""
        ...


class MistralStream:
    def __init__(self):
        self.current_message_index = 0
        self.mistral = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

    async def chat_completion(
        self, messages: list[dict[str, str]]
    ) -> AsyncIterator[str]:
        event_stream = await self.mistral.chat.stream_async(
            model="mistral-large-latest",
            messages=cast(Any, messages),  # It's too annoying to type this properly
            temperature=1.0,
        )

        async for event in event_stream:
            delta = event.data.choices[0].delta.content
            assert isinstance(delta, str)  # make Pyright happy
            yield delta


def get_openai_client(server_url: str = LLM_SERVER) -> AsyncOpenAI:
    return AsyncOpenAI(api_key="EMPTY", base_url=server_url + "/v1")


@cache
def autoselect_model() -> str:
    if KYUTAI_LLM_MODEL is not None:
        return KYUTAI_LLM_MODEL

    client_sync = OpenAI(api_key="EMPTY", base_url=get_openai_client().base_url)
    models = client_sync.models.list()
    if len(models.data) != 1:
        raise ValueError(
            f"Expected exactly one available model, found {len(models.data)}. "
            "Please specify one."
        )
    return models.data[0].id


class VLLMStream:
    def __init__(
        self,
        client: AsyncOpenAI,
        temperature: float = 1.0,
        model: str | None = None,  # Optional: pin a model instead of autoselecting.
    ):
        """
        If `model` is None, look at the available models; if there is exactly one,
        use it. Otherwise, raise.
        """
        self.client = client
        # Use the provided model, or fall back to autoselection if none is given.
        self.model = model if model is not None else autoselect_model()
        self.temperature = temperature

    async def chat_completion(
        self, messages: list[dict[str, str]]
    ) -> AsyncIterator[str]:
        stream = await self.client.chat.completions.create(
            model=self.model,
            messages=cast(Any, messages),  # Cast and hope for the best
            stream=True,
            temperature=self.temperature,
        )

        async with stream:
            async for chunk in stream:
                # Defensive check: skip chunks with no choices or no delta (e.g. a
                # final usage-only chunk) instead of indexing into them blindly.
                if not chunk.choices or not chunk.choices[0].delta:
                    continue

                chunk_content = chunk.choices[0].delta.content

                # The key change for llama-server: the delta content can be None on
                # the initial (role-only) and final chunks of an OpenAI-style stream,
                # so skip those instead of asserting that it's always a string.
                if chunk_content is not None:
                    assert isinstance(chunk_content, str), (
                        f"Expected str but got {type(chunk_content)}: {chunk_content!r}"
                    )
                    yield chunk_content
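

# Hypothetical end-to-end sketch (not part of the original gist): wire the pieces
# together against a local llama-server/vLLM endpoint. Assumes LLM_SERVER points at
# a running OpenAI-compatible server with exactly one model loaded.
if __name__ == "__main__":
    import asyncio

    async def _main() -> None:
        stream = VLLMStream(get_openai_client())
        messages = preprocess_messages_for_llm(
            [
                {"role": "system", "content": "You are a helpful voice assistant."},
                {"role": "user", "content": "Say hello in five words."},
            ]
        )
        # Rechunk raw token deltas into whole words, as the TTS expects.
        async for word in rechunk_to_words(stream.chat_completion(messages)):
            print(repr(word))

    asyncio.run(_main())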