Skip to content

Instantly share code, notes, and snippets.

@RyanBalfanz
Created April 16, 2024 02:31
Show Gist options
  • Save RyanBalfanz/cd08b7402594fa91831bf8d54c76e0ec to your computer and use it in GitHub Desktop.
Save RyanBalfanz/cd08b7402594fa91831bf8d54c76e0ec to your computer and use it in GitHub Desktop.
Initial LLM WebUI
import json
import urllib.parse
from dataclasses import InitVar, dataclass
from datetime import datetime, timedelta
from typing import Protocol
import httpx
import llm
import streamlit as st
@dataclass
class Config:
    """Application-wide settings for the LLM web UI.

    Every field has a default, so ``Config()`` yields a ready-to-use
    configuration pointing at a local Datasette server.
    """

    # Heading rendered at the top of the sidebar.
    app_title: str = "LLM WebUI"
    # Root URL of the Datasette server that serves the ``logs`` database.
    datasette_base_url: str = "http://127.0.0.1:8001/"
    # Label shown by ``st.spinner`` while data is being fetched.
    loading_text: str = "Loading…"
    # Icon passed to ``st.success`` once a fetch completes.
    success_icon: str = "🎉"
@dataclass
class Conversation:
    """A single row of the ``conversations`` table in the logs database."""

    # Conversation identifier (column 0 of the row).
    id: str
    # Human-readable conversation name (column 1); used as the button label.
    name: str
    # Model name recorded for the conversation (column 2).
    model: str
@dataclass
class Response:
id: str
model: str
prompt: str
system: str
prompt_json: dict | None
options_json: dict
response: str
response_json: dict | None
conversation_id: str
duration_ms: timedelta
datetime_utc: datetime
@dataclass
class DatasetteLogsClient:
base_url: str
def urljoin(self, s: str):
"""Join a relative URL with the base URL."""
return urllib.parse.urljoin(self.base_url, s)
def get_conversations(self):
"""Fetch all conversations from the Datasette logs."""
url = self.urljoin("/logs/conversations.json")
response = httpx.get(url)
response.raise_for_status()
for row in (rows := response.json()["rows"]):
yield Conversation(id=row[0], name=row[1], model=row[2])
def get_responses(self):
"""Fetch all responses from the Datasette logs."""
raise NotImplementedError("This method is not yet implemented.")
def get_responses_by_conversation_id(self, conversation_id):
"""Fetch all responses for a given conversation ID from the Datasette logs."""
url = self.urljoin("/logs/responses.json")
response = httpx.get(url, params={"conversation_id": conversation_id})
response.raise_for_status()
data = response.json()
if "rows" not in data:
raise Exception(f"Unexpected response: {data}")
for row in (rows := data["rows"]):
assert (
row[8] == conversation_id
) # Sanity check that the conversation ID matches.
yield Response(
id=row[0],
model=row[1],
prompt=row[2],
system=row[3],
prompt_json=json.loads(row[4]) if row[7] is not None else None,
options_json=json.loads(row[5]),
response=row[6],
response_json=json.loads(row[7]) if row[7] is not None else None,
conversation_id=row[8],
duration_ms=timedelta(milliseconds=row[9]),
datetime_utc=datetime.fromisoformat(row[10]),
)
class Repository[T](Protocol):
def all(self) -> list[T] | None: ...
def get(self, id: str) -> T | None: ...
@dataclass
class ConversationsRepository(Repository[Conversation]):
    """Repository of ``Conversation`` records backed by a ``DatasetteLogsClient``.

    Pass ``base_url`` to have a client constructed automatically; an explicitly
    supplied ``client`` is kept and takes precedence.
    """

    base_url: InitVar[str]
    client: DatasetteLogsClient | None = None

    def __post_init__(self, base_url):
        # Build a default client only when none was injected and a URL exists.
        if self.client is not None or base_url is None:
            return
        self.client = DatasetteLogsClient(base_url)

    def all(self) -> list[Conversation]:
        """Return every conversation the client yields.

        Raises:
            Exception: If no client is configured.
        """
        client = self.client
        if client is None:
            raise Exception("Client is not initialized")
        conversations = client.get_conversations()
        return list(conversations)

    def get(self, id: str) -> Conversation | None:
        """Return the first conversation whose ``id`` matches, else ``None``."""
        return next((item for item in self.all() if item.id == id), None)
@dataclass
class ResponsesRepository(Repository[Response]):
    """Repository of ``Response`` records backed by a ``DatasetteLogsClient``.

    Pass ``base_url`` to have a client constructed automatically; an explicitly
    supplied ``client`` is kept and takes precedence.
    """

    base_url: InitVar[str]
    client: DatasetteLogsClient | None = None

    def __post_init__(self, base_url):
        # Build a default client only when none was injected and a URL exists.
        if self.client is not None or base_url is None:
            return
        self.client = DatasetteLogsClient(base_url)

    def _require_client(self):
        """Return the configured client, raising if none is set."""
        if self.client is None:
            raise Exception("Client is not initialized")
        return self.client

    def all(self) -> list[Response]:
        """Return every response yielded by the client's ``get_responses``."""
        return list(self._require_client().get_responses())

    def get(self, id: str) -> Response | None:
        """Return the first response whose ``id`` matches, else ``None``."""
        return next((item for item in self.all() if item.id == id), None)

    def get_by_conversation_id(self, conversation_id: str) -> list[Response]:
        """Return all responses that belong to ``conversation_id``.

        Raises:
            Exception: If no client is configured.
        """
        client = self._require_client()
        return list(client.get_responses_by_conversation_id(conversation_id))
def init(config: Config):
    """Seed ``st.session_state`` with default values on the first script run.

    Keys that already exist (for example from an earlier rerun) keep their
    current values; only missing keys are assigned.
    """
    defaults = {
        "config": config,
        "conversations": [],
        "selected_conversation_id": None,
        "responses": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def get_conversations():
    """Yield every conversation from the Datasette server in the session config.

    The repository is built only when iteration starts, using
    ``st.session_state.config.datasette_base_url``.
    """
    base_url = st.session_state.config.datasette_base_url
    yield from ConversationsRepository(base_url).all()
def get_responses_by_conversation_id(conversation_id: str):
    """Yield the responses logged for ``conversation_id``.

    The repository is built only when iteration starts, using
    ``st.session_state.config.datasette_base_url``.
    """
    base_url = st.session_state.config.datasette_base_url
    yield from ResponsesRepository(base_url).get_by_conversation_id(conversation_id)
def conversation_on_change():
    """Reload responses for the conversation currently selected in session state.

    Stores the fetched list in ``st.session_state.responses`` and reports how
    many responses were loaded with ``st.success``.
    """
    settings = st.session_state.config
    with st.spinner(settings.loading_text):
        selected_id = st.session_state.selected_conversation_id
        fetched = list(get_responses_by_conversation_id(selected_id))
        st.session_state.responses = fetched
    st.success(
        f"Fetched {len(st.session_state.responses)} responses",
        icon=settings.success_icon,
    )
def _render_sidebar() -> None:
    """Load conversations into session state and list them as sidebar buttons.

    Calls ``st.stop()`` (ending this script run) when there are no conversations.
    """
    with st.sidebar:
        st.title(st.session_state.config.app_title)
        with st.spinner(st.session_state.config.loading_text):
            st.session_state.conversations = list(get_conversations())
        st.success(
            f"Fetched {len(st.session_state.conversations)} conversations",
            icon=st.session_state.config.success_icon,
        )
        if not st.session_state.conversations:
            st.stop()

        def on_click(conversation):
            st.session_state.update({"selected_conversation_id": conversation.id})
            conversation_on_change()

        for conversation in st.session_state.conversations:
            # ``args`` binds each conversation to its own button, so the callback
            # does not depend on a late-bound loop variable.
            st.button(
                conversation.name,
                key=f"conversation_{conversation.id}",
                on_click=on_click,
                args=(conversation,),
            )


def _render_summary(conversation_id: str) -> None:
    """Render a title and a one-row Markdown table summarising the conversation."""
    st.title(f"Conversation: {conversation_id}")
    responses = st.session_state.responses
    if responses:
        last = responses[-1]
        model, last_at = last.model, last.datetime_utc.isoformat()
    else:
        # No responses were fetched; avoid IndexError on responses[-1].
        model, last_at = "—", "—"
    st.write(
        "\n".join(
            [
                "|Conversation|Model|Last Response|",
                "|---|---|---|",
                f"|{conversation_id}|{model}|{last_at}|",
            ]
        )
    )


def _render_transcript() -> None:
    """Render each stored response as a prompt bubble followed by a reply bubble."""
    if "responses" not in st.session_state:
        return
    for response in st.session_state.responses:
        # Display the prompt. Default to "user" unless prompt_json supplies a role.
        name = "user"
        messages = response.prompt_json.get("messages") if response.prompt_json else None
        if messages:
            # NOTE(review): uses the *first* message's role as the prompt author;
            # confirm this matches how the model plugin orders its messages.
            name = messages[0].get("role", name)
        with st.chat_message(name):
            st.markdown(response.prompt)
        # Display the response.
        with st.chat_message("ai"):
            st.markdown(response.response)


def main() -> None:
    """Entry point: initialise state, render the sidebar, then the selected conversation."""
    init(Config())
    _render_sidebar()
    selected = st.session_state.get("selected_conversation_id")
    if selected:
        _render_summary(selected)
    # Display all prompt-response interactions for the selected conversation.
    _render_transcript()
# Script entry point; main()'s return value becomes the SystemExit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
@RyanBalfanz
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment