Created
April 16, 2024 02:31
-
-
Save RyanBalfanz/cd08b7402594fa91831bf8d54c76e0ec to your computer and use it in GitHub Desktop.
Initial LLM WebUI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import urllib.parse
from dataclasses import InitVar, dataclass
from datetime import datetime, timedelta
from typing import Protocol

import httpx
import llm
import streamlit as st
@dataclass()
class Config:
    """Static UI and connection settings for the LLM WebUI app."""

    # Shown via st.title at the top of the sidebar.
    app_title: str = "LLM WebUI"
    # Base URL of the Datasette instance serving the llm logs database.
    datasette_base_url: str = "http://127.0.0.1:8001/"
    # Text displayed inside st.spinner while data is being fetched.
    loading_text: str = "Loading…"
    # Icon passed to st.success after a successful fetch.
    success_icon: str = "🎉"
@dataclass()
class Conversation:
    """One row of the ``conversations`` table read from the llm logs database."""

    id: str  # identifier; Response.conversation_id refers back to this value
    name: str  # used as the button label when listing conversations in main()
    model: str  # model identifier recorded for the conversation
@dataclass | |
class Response: | |
id: str | |
model: str | |
prompt: str | |
system: str | |
prompt_json: dict | None | |
options_json: dict | |
response: str | |
response_json: dict | None | |
conversation_id: str | |
duration_ms: timedelta | |
datetime_utc: datetime | |
@dataclass | |
class DatasetteLogsClient: | |
base_url: str | |
def urljoin(self, s: str): | |
"""Join a relative URL with the base URL.""" | |
return urllib.parse.urljoin(self.base_url, s) | |
def get_conversations(self): | |
"""Fetch all conversations from the Datasette logs.""" | |
url = self.urljoin("/logs/conversations.json") | |
response = httpx.get(url) | |
response.raise_for_status() | |
for row in (rows := response.json()["rows"]): | |
yield Conversation(id=row[0], name=row[1], model=row[2]) | |
def get_responses(self): | |
"""Fetch all responses from the Datasette logs.""" | |
raise NotImplementedError("This method is not yet implemented.") | |
def get_responses_by_conversation_id(self, conversation_id): | |
"""Fetch all responses for a given conversation ID from the Datasette logs.""" | |
url = self.urljoin("/logs/responses.json") | |
response = httpx.get(url, params={"conversation_id": conversation_id}) | |
response.raise_for_status() | |
data = response.json() | |
if "rows" not in data: | |
raise Exception(f"Unexpected response: {data}") | |
for row in (rows := data["rows"]): | |
assert ( | |
row[8] == conversation_id | |
) # Sanity check that the conversation ID matches. | |
yield Response( | |
id=row[0], | |
model=row[1], | |
prompt=row[2], | |
system=row[3], | |
prompt_json=json.loads(row[4]) if row[7] is not None else None, | |
options_json=json.loads(row[5]), | |
response=row[6], | |
response_json=json.loads(row[7]) if row[7] is not None else None, | |
conversation_id=row[8], | |
duration_ms=timedelta(milliseconds=row[9]), | |
datetime_utc=datetime.fromisoformat(row[10]), | |
) | |
class Repository[T](Protocol): | |
def all(self) -> list[T] | None: ... | |
def get(self, id: str) -> T | None: ... | |
@dataclass
class ConversationsRepository(Repository[Conversation]):
    """Read-only access to logged conversations through a ``DatasetteLogsClient``.

    Pass either a Datasette ``base_url`` (a client is built from it) or a
    ready-made ``client``; an explicitly supplied client takes precedence.
    """

    base_url: InitVar[str | None]
    client: DatasetteLogsClient | None = None

    def __post_init__(self, base_url: str | None) -> None:
        if self.client is None and base_url is not None:
            self.client = DatasetteLogsClient(base_url)

    def _require_client(self) -> DatasetteLogsClient:
        """Return the configured client or raise ``RuntimeError``."""
        if self.client is None:
            raise RuntimeError("Client is not initialized")
        return self.client

    def all(self) -> list[Conversation]:
        """Return every conversation.

        Raises:
            RuntimeError: if no client was configured.
        """
        return list(self._require_client().get_conversations())

    def get(self, id: str) -> Conversation | None:
        """Return the first conversation whose ``id`` matches, or ``None``."""
        return next((c for c in self.all() if c.id == id), None)
@dataclass
class ResponsesRepository(Repository[Response]):
    """Read-only access to logged responses through a ``DatasetteLogsClient``.

    Pass either a Datasette ``base_url`` (a client is built from it) or a
    ready-made ``client``; an explicitly supplied client takes precedence.
    """

    base_url: InitVar[str | None]
    client: DatasetteLogsClient | None = None

    def __post_init__(self, base_url: str | None) -> None:
        if self.client is None and base_url is not None:
            self.client = DatasetteLogsClient(base_url)

    def _require_client(self) -> DatasetteLogsClient:
        """Return the configured client or raise ``RuntimeError``."""
        if self.client is None:
            raise RuntimeError("Client is not initialized")
        return self.client

    def all(self) -> list[Response]:
        """Return every logged response.

        Raises:
            RuntimeError: if no client was configured.
            NotImplementedError: propagated from the client's
                ``get_responses`` while that method is unimplemented.
        """
        return list(self._require_client().get_responses())

    def get(self, id: str) -> Response | None:
        """Return the first response whose ``id`` matches, or ``None``.

        Built on ``all()``, so it raises whatever ``all()`` raises.
        """
        return next((r for r in self.all() if r.id == id), None)

    def get_by_conversation_id(self, conversation_id: str) -> list[Response]:
        """Return all responses belonging to ``conversation_id``.

        Raises:
            RuntimeError: if no client was configured.
        """
        client = self._require_client()
        return list(client.get_responses_by_conversation_id(conversation_id))
def init(config: Config):
    """Seed ``st.session_state`` with default values for any missing keys.

    Keys that already exist (for example, set by an earlier script run) are
    left untouched.
    """
    defaults = {
        "config": config,
        "conversations": [],
        "selected_conversation_id": None,
        "responses": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def get_conversations():
    """Lazily yield every conversation from the configured Datasette instance."""
    base_url = st.session_state.config.datasette_base_url
    yield from ConversationsRepository(base_url).all()
def get_responses_by_conversation_id(conversation_id: str):
    """Lazily yield the responses logged for ``conversation_id``."""
    base_url = st.session_state.config.datasette_base_url
    yield from ResponsesRepository(base_url).get_by_conversation_id(conversation_id)
def conversation_on_change():
    """Load responses for the selected conversation into session state.

    Shows a spinner while fetching and then reports how many were loaded.
    """
    cfg = st.session_state.config
    with st.spinner(cfg.loading_text):
        fetched = get_responses_by_conversation_id(
            st.session_state.selected_conversation_id
        )
        st.session_state.responses = list(fetched)
    count = len(st.session_state.responses)
    st.success(f"Fetched {count} responses", icon=cfg.success_icon)
def _prompt_role(response) -> str:
    """Return the chat-message name used to display ``response``'s prompt.

    Uses the ``role`` of the *last* entry in ``prompt_json["messages"]`` when
    available, otherwise ``"user"``. The first entry is not used because it can
    be a system message; presumably the last entry is the prompt being logged
    (true for llm's OpenAI chat models — TODO confirm for other plugins).
    """
    prompt_json = response.prompt_json if response else None
    if isinstance(prompt_json, dict):
        messages = prompt_json.get("messages")
        if isinstance(messages, list) and messages:
            last = messages[-1]
            if isinstance(last, dict) and last.get("role"):
                return last["role"]
    return "user"


def _select_conversation(conversation) -> None:
    """Button callback: mark ``conversation`` as selected and load its responses."""
    st.session_state.update({"selected_conversation_id": conversation.id})
    conversation_on_change()


def _render_sidebar() -> None:
    """Fetch all conversations and list them as buttons in the sidebar.

    Calls ``st.stop()`` when there are no conversations to show.
    """
    config = st.session_state.config
    with st.sidebar:
        st.title(config.app_title)
        with st.spinner(config.loading_text):
            st.session_state.conversations = list(get_conversations())
        st.success(
            f"Fetched {len(st.session_state.conversations)} conversations",
            icon=config.success_icon,
        )
        if not st.session_state.conversations:
            st.stop()
        for conversation in st.session_state.conversations:
            st.button(
                conversation.name,
                key=f"conversation_{conversation.id}",
                on_click=_select_conversation,
                args=(conversation,),
            )


def _render_conversation_summary(conversation_id: str, responses) -> None:
    """Show a title and a one-row Markdown summary table for a conversation."""
    st.title(f"Conversation: {conversation_id}")
    if not responses:
        # No responses means no "last response": indexing [-1] would raise.
        return
    # Choose the newest response by timestamp rather than trusting row order.
    latest = max(responses, key=lambda r: r.datetime_utc)
    st.write(
        "\n".join(
            [
                "|Conversation|Model|Last Response|",
                "|---|---|---|",
                "|"
                + "|".join(
                    [
                        conversation_id,
                        latest.model,
                        latest.datetime_utc.isoformat(),
                    ]
                )
                + "|",
            ]
        )
    )


def _render_messages(responses) -> None:
    """Render each logged prompt/response pair as two chat messages."""
    for response in responses:
        # Display the prompt.
        with st.chat_message(_prompt_role(response)):
            st.markdown(response.prompt)
        # Display the response.
        with st.chat_message("ai"):
            st.markdown(response.response)


def main() -> None:
    """Streamlit entry point: conversation list in the sidebar, transcript in the main area."""
    init(Config())
    _render_sidebar()
    selected_id = st.session_state.get("selected_conversation_id")
    if selected_id:
        _render_conversation_summary(selected_id, st.session_state.get("responses", []))
    # Display all prompt-response interactions for the selected conversation.
    _render_messages(st.session_state.get("responses", []))
if __name__ == "__main__":
    # Propagate main()'s return value as the exit status (None means success).
    exit_status = main()
    raise SystemExit(exit_status)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
First public mention was at https://discord.com/channels/823971286308356157/1128504153841336370/1229617445115789353.