# Atomic Agents + Streamlit Mermaid Diagram Assistant
#
# Originally published as a GitHub Gist by @KennyVaneetvelde on August 7, 2024
# (gist 3a0882f84c4d5b5fd5882db290343b0a).
import instructor
import openai
import streamlit as st
from pydantic import Field
from dataclasses import dataclass
import streamlit.components.v1 as components
from atomic_agents.agents.base_agent import BaseIOSchema, BaseAgent, BaseAgentConfig
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator
# Shared OpenAI client, wrapped by instructor so agents can request structured
# (schema-validated) responses. Every agent built in this module uses it.
client = instructor.from_openai(
    openai.OpenAI()
)
##############################
# AGENT INPUT/OUTPUT SCHEMAS #
##############################
class MermaidAgentInputSchema(BaseIOSchema):
    """Input schema for the MermaidAgent."""

    # Natural-language request typed by the user; required (default=... means
    # pydantic treats the field as mandatory).
    prompt: str = Field(
        default=...,
        description="A description of the flowchart to be created using MermaidJS.",
    )
class MermaidAgentOutputSchema(BaseIOSchema):
    """Output schema for the MermaidAgent."""

    # Step-by-step notes from the model; the app lists them in the
    # "Internal Reasoning" expander.
    internal_reasoning: list[str] = Field(
        default=...,
        description="Internal reasoning steps for the flowchart creation.",
    )

    # Mermaid source text; the app renders it and also shows it in the
    # "Mermaid Code" expander.
    mermaid_flowchart: str = Field(
        default=...,
        description=(
            "The generated MermaidJS schema for the flowchart. "
            "This schema must be rendered using MermaidJS and be completely "
            "compatible with the MermaidJS library."
        ),
    )

    # Short reply appended to the chat history as the assistant's message.
    message_to_user: str = Field(
        default=...,
        description="A message to the user about the flowchart creation.",
    )
#############################
# MERMAID AGENT INITIALIZER #
#############################
def initialize_mermaid_agent(*, model: str = "gpt-4o-mini") -> BaseAgent:
    """Build the agent that turns a diagram description into Mermaid source.

    Args:
        model: Model name passed to ``BaseAgentConfig``. Defaults to
            ``"gpt-4o-mini"``, the value this function previously hard-coded.

    Returns:
        A ``BaseAgent`` configured with the module-level ``client``, taking
        ``MermaidAgentInputSchema`` and producing ``MermaidAgentOutputSchema``.
    """
    return BaseAgent(
        BaseAgentConfig(
            client=client,
            model=model,
            system_prompt_generator=SystemPromptGenerator(
                background=[
                    "You are an AI assistant specialized in generating MermaidJS flowchart schemas.",
                    "You have extensive knowledge of MermaidJS syntax and flowchart types.",
                ],
                steps=[
                    "You will receive a description of a flowchart to be created.",
                    "Analyze the description and determine the most appropriate MermaidJS flowchart type.",
                    "Generate the MermaidJS schemas for the described flowchart.",
                ],
                # The output schema also carries `internal_reasoning` and
                # `message_to_user`, so say which content belongs where rather
                # than forbidding all explanation. `mermaid_flowchart` is
                # injected verbatim into the Mermaid <div>, so Markdown fences
                # or prose inside it would break rendering.
                output_instructions=[
                    "Put only the MermaidJS code in the mermaid_flowchart field: no Markdown code fences and no additional explanation.",
                    "Ensure the schema is valid and follows MermaidJS syntax. Do not include anything that is not related to the flowchart creation.",
                    "Include appropriate styling and layout options to enhance readability.",
                    "Use message_to_user for a short explanation of the generated flowchart for the user.",
                ],
            ),
            input_schema=MermaidAgentInputSchema,
            output_schema=MermaidAgentOutputSchema,
        )
    )
##############################
# MERMAID RENDERING FUNCTION #
##############################
def render_mermaid(mermaid_flowchart: str, height: int = 750) -> None:
    """Render Mermaid diagram source inside a Streamlit HTML component.

    Args:
        mermaid_flowchart: Mermaid source text to draw.
        height: Height in pixels of the component and maximum height of the
            scrollable diagram container. Defaults to 750, the value this
            function previously hard-coded.
    """
    import html  # only needed here; escapes the diagram for safe embedding

    # The diagram text comes from the LLM, so escape it before splicing it into
    # the HTML document. Unescaped, markup such as "</div>" or "<script>" in the
    # model output would be interpreted by the browser (this page also runs
    # Mermaid with securityLevel 'loose'). Escaping leaves the element's text
    # content unchanged, so Mermaid sees the same diagram.
    # NOTE(review): relies on Mermaid decoding entities in the element content
    # before parsing; confirm against the Mermaid version loaded below.
    safe_flowchart = html.escape(mermaid_flowchart)
    components.html(
        f"""
        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
        <script src="https://cdn.jsdelivr.net/npm/mermaid@latest/dist/mermaid.min.js"></script>
        <div class="mermaid-container" style="overflow-y: auto; max-height: {height}px;">
            <div class="mermaid">
                {safe_flowchart}
            </div>
        </div>
        <script>
            mermaid.initialize({{
                startOnLoad: true,
                fontFamily: 'monospace, sans-serif',
                flowchart: {{
                    htmlLabels: true,
                    useMaxWidth: true,
                }},
                securityLevel: 'loose',
            }});
            mermaid.parseError = function(err, hash) {{
                console.error('Mermaid error:', err);
            }};
        </script>
        """,
        height=height,
    )
######################
# MAIN APPLICATION #
######################
def _handle_prompt(user_prompt):
    """Send *user_prompt* to the agent and record the outcome in session state."""
    st.session_state.chat_history.append({"role": "user", "content": user_prompt})
    try:
        with st.spinner("Generating flowchart..."):
            response = st.session_state.mermaid_agent.run(
                MermaidAgentInputSchema(prompt=user_prompt)
            )
    except Exception as err:  # UI boundary: report the failure instead of crashing the page
        # Answer the user's message so the chat history stays consistent.
        st.session_state.chat_history.append(
            {
                "role": "assistant",
                "content": f"Sorry, the flowchart could not be generated: {err}",
            }
        )
        st.error(f"Flowchart generation failed: {err}")
        return

    st.session_state.chat_history.append(
        {"role": "assistant", "content": response.message_to_user}
    )
    # Keep the latest result so the right-hand column can show it on every rerun.
    st.session_state.generated_flowchart = response.mermaid_flowchart
    st.session_state.internal_reasoning = response.internal_reasoning
    st.success("Flowchart generated successfully!")


def _render_chat_column():
    """Render the chat history and input box, handling any newly submitted prompt."""
    st.header("Chat with AI Flowchart Creator")
    # Reserve a slot for the history above the input box. It is filled after the
    # new prompt (if any) is processed, so the latest exchange appears immediately.
    chat_container = st.container()
    user_prompt = st.chat_input("Enter your flowchart description:")
    if user_prompt:
        _handle_prompt(user_prompt)
    with chat_container:
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])


def _render_flowchart_column():
    """Render the most recent flowchart with its source and reasoning, or a hint."""
    st.header("Generated Flowchart")
    if "generated_flowchart" not in st.session_state:
        st.info("Generate a flowchart to see it displayed here.")
        return
    render_mermaid(st.session_state.generated_flowchart)
    with st.expander("Mermaid Code", expanded=False):
        st.code(st.session_state.generated_flowchart, language="mermaid")
    with st.expander("Internal Reasoning", expanded=False):
        for step in st.session_state.get("internal_reasoning", []):
            st.write(step)


def main():
    """Run the Streamlit app: chat on the left, rendered flowchart on the right.

    The agent and chat history are kept in ``st.session_state`` and created
    only when missing, so later reruns of this script reuse them.
    """
    st.set_page_config(layout="wide", page_title="AI-Powered Flowchart Creator")
    st.title("AI-Powered Flowchart Creator")

    if "mermaid_agent" not in st.session_state:
        st.session_state.mermaid_agent = initialize_mermaid_agent()
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Wider chat column (3) next to a narrower flowchart column (2).
    col1, col2 = st.columns([3, 2])
    with col1:
        _render_chat_column()
    with col2:
        _render_flowchart_column()
# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()
# End of gist.