# Atomic Agents + Streamlit Mermaid Diagram Assistant
#
# Source: GitHub gist KennyVaneetvelde/3a0882f84c4d5b5fd5882db290343b0a,
# created August 7, 2024 14:20.
#
# NOTE: GitHub flagged this file as possibly containing hidden or
# bidirectional Unicode text, which may be interpreted or compiled
# differently than it appears. Review it in an editor that reveals hidden
# Unicode characters.
import html
from dataclasses import dataclass

import instructor
import openai
import streamlit as st
import streamlit.components.v1 as components
from pydantic import Field
from atomic_agents.agents.base_agent import BaseIOSchema, BaseAgent, BaseAgentConfig
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator
# Initialize OpenAI client.
# The plain OpenAI SDK client is built with no arguments, so its settings come
# from the library defaults (presumably the OPENAI_API_KEY environment
# variable). instructor then wraps it so the agent below can request
# schema-shaped responses. The temporary name is removed so that only
# `client` remains at module level.
_openai_sdk_client = openai.OpenAI()
client = instructor.from_openai(_openai_sdk_client)
del _openai_sdk_client
##############################
# AGENT INPUT/OUTPUT SCHEMAS #
##############################
class MermaidAgentInputSchema(BaseIOSchema):
    """Input schema for the MermaidAgent."""

    # NOTE: pydantic uses the class docstring as the JSON-schema description,
    # and the field description below also lands in that schema, which the
    # agent presumably forwards to the model. Treat both as prompt text.

    # Required free-text request from the chat input (default=... means required).
    prompt: str = Field(
        default=...,
        description="A description of the flowchart to be created using MermaidJS.",
    )
class MermaidAgentOutputSchema(BaseIOSchema):
    """Output schema for the MermaidAgent."""

    # NOTE: the docstring and every `description` string here end up in the
    # JSON schema (pydantic keeps declaration order for properties), which
    # the agent presumably hands to the model. They are runtime prompt text,
    # not just documentation, so keep the field order and wording stable.

    # Reasoning steps; main() lists them in the "Internal Reasoning" expander.
    internal_reasoning: list[str] = Field(
        default=..., description="Internal reasoning steps for the flowchart creation."
    )
    # Mermaid source; main() passes it to render_mermaid() and shows it as code.
    mermaid_flowchart: str = Field(
        default=...,
        description="The generated MermaidJS schema for the flowchart. This schema must be rendered using MermaidJS and be completely compatible with the MermaidJS library.",
    )
    # Reply appended to the chat history as the assistant message.
    message_to_user: str = Field(
        default=..., description="A message to the user about the flowchart creation."
    )
#############################
# MERMAID AGENT INITIALIZER #
#############################
def initialize_mermaid_agent(*, model="gpt-4o-mini"):
    """Build the agent that turns flowchart descriptions into MermaidJS.

    The agent uses the module-level instructor-wrapped ``client``. It is
    constrained to ``MermaidAgentInputSchema`` for input and
    ``MermaidAgentOutputSchema`` for output.

    Args:
        model: Chat model name passed to ``BaseAgentConfig``. Defaults to
            "gpt-4o-mini", the value previously hard-coded here.

    Returns:
        A configured ``BaseAgent``.
    """
    system_prompt_generator = SystemPromptGenerator(
        background=[
            "You are an AI assistant specialized in generating MermaidJS flowchart schemas.",
            "You have extensive knowledge of MermaidJS syntax and flowchart types.",
        ],
        steps=[
            "You will receive a description of a flowchart to be created.",
            "Analyze the description and determine the most appropriate MermaidJS flowchart type.",
            "Generate the MermaidJS schemas for the described flowchart.",
        ],
        output_instructions=[
            # The output schema also requires internal_reasoning and
            # message_to_user. So the "no explanation" rule is scoped to the
            # diagram field instead of forbidding explanation altogether.
            # Markdown fences are called out because they are not valid
            # Mermaid and break rendering.
            "Put only the MermaidJS diagram definition in the mermaid_flowchart field, without Markdown code fences or any additional explanation.",
            "Ensure the schema is valid and follows MermaidJS syntax. Do not include anything that is not related to the flowchart creation.",
            "Include appropriate styling and layout options to enhance readability.",
        ],
    )
    return BaseAgent(
        BaseAgentConfig(
            client=client,
            model=model,
            system_prompt_generator=system_prompt_generator,
            input_schema=MermaidAgentInputSchema,
            output_schema=MermaidAgentOutputSchema,
        )
    )
##############################
# MERMAID RENDERING FUNCTION #
##############################
def render_mermaid(mermaid_flowchart, *, height=750):
    """Render a MermaidJS diagram inside a Streamlit HTML component.

    The diagram text is cleaned before it is embedded in the page:

    * A surrounding Markdown code fence (```mermaid ... ```) is removed.
      LLMs sometimes emit one, and a fence is never valid Mermaid syntax,
      so such input could not render as-is.
    * The text is HTML-escaped. Otherwise characters such as ``<``, ``>``
      and ``&`` in labels are parsed by the browser as markup, which mangles
      the diagram or injects arbitrary HTML/JS into the component. Mermaid
      reads the element's innerHTML and decodes entities itself, so it still
      receives the original diagram text.

    Args:
        mermaid_flowchart: Mermaid diagram definition. ``None`` is treated as
            an empty diagram.
        height: Height in pixels of the component and of its scrollable
            container. Defaults to 750, the previously hard-coded value.
    """
    diagram = (mermaid_flowchart or "").strip()
    if diagram.startswith("```") and "\n" in diagram:
        # Drop the opening fence line (``` or ```mermaid) and, if present,
        # the closing ``` line.
        body_lines = diagram.split("\n")[1:]
        if body_lines and body_lines[-1].strip() == "```":
            body_lines = body_lines[:-1]
        diagram = "\n".join(body_lines).strip()
    escaped_diagram = html.escape(diagram)
    # NOTE(review): mermaid@latest can pick up breaking major releases;
    # consider pinning a major version (e.g. mermaid@11).
    components.html(
        f"""
        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
        <script src="https://cdn.jsdelivr.net/npm/mermaid@latest/dist/mermaid.min.js"></script>
        <div class="mermaid-container" style="overflow-y: auto; max-height: {height}px;">
            <div class="mermaid">
                {escaped_diagram}
            </div>
        </div>
        <script>
            mermaid.initialize({{
                startOnLoad: true,
                fontFamily: 'monospace, sans-serif',
                flowchart: {{
                    htmlLabels: true,
                    useMaxWidth: true,
                }},
                securityLevel: 'loose',
            }});
            mermaid.parseError = function(err, hash) {{
                console.error('Mermaid error:', err);
            }};
        </script>
        """,
        height=height,
    )
######################
#  MAIN APPLICATION  #
######################
def _init_session_state():
    """Create the Mermaid agent and the chat history on the first script run."""
    # Streamlit reruns the script on every interaction; session_state keeps
    # these objects alive across reruns.
    if "mermaid_agent" not in st.session_state:
        st.session_state.mermaid_agent = initialize_mermaid_agent()
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []


def _handle_prompt(user_prompt):
    """Send *user_prompt* to the agent and record the outcome in session state.

    On success, the assistant reply, the diagram and the reasoning are stored
    for display. If the agent call raises, the error is shown and recorded as
    an assistant message instead of crashing the whole page. The previous
    diagram, if any, stays visible.
    """
    history = st.session_state.chat_history
    history.append({"role": "user", "content": user_prompt})
    try:
        with st.spinner("Generating flowchart..."):
            response = st.session_state.mermaid_agent.run(
                MermaidAgentInputSchema(prompt=user_prompt)
            )
    except Exception as exc:  # UI boundary: the LLM/network call can fail many ways
        history.append(
            {
                "role": "assistant",
                "content": f"Sorry, I couldn't generate a flowchart: {exc}",
            }
        )
        st.error(f"Flowchart generation failed: {exc}")
        return
    history.append({"role": "assistant", "content": response.message_to_user})
    # Store the latest result so the flowchart column shows it on every rerun.
    st.session_state.generated_flowchart = response.mermaid_flowchart
    st.session_state.internal_reasoning = response.internal_reasoning
    st.success("Flowchart generated successfully!")


def _render_chat_column():
    """Render the chat header, the message history and the prompt input."""
    st.header("Chat with AI Flowchart Creator")
    # Reserve a slot above the input. The history is written into it after
    # any new prompt has been processed, so the latest exchange is included.
    chat_container = st.container()
    user_prompt = st.chat_input("Enter your flowchart description:")
    if user_prompt:
        _handle_prompt(user_prompt)
    with chat_container:
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.write(message["content"])


def _render_flowchart_column():
    """Render the latest diagram plus collapsible Mermaid source and reasoning."""
    st.header("Generated Flowchart")
    if "generated_flowchart" not in st.session_state:
        st.info("Generate a flowchart to see it displayed here.")
        return
    render_mermaid(st.session_state.generated_flowchart)
    with st.expander("Mermaid Code", expanded=False):
        st.code(st.session_state.generated_flowchart, language="mermaid")
    with st.expander("Internal Reasoning", expanded=False):
        for step in st.session_state.get("internal_reasoning", []):
            st.write(step)


def main():
    """Run the AI-Powered Flowchart Creator Streamlit app.

    The page has a chat column (relative width 3) where the user describes a
    flowchart, and a flowchart column (relative width 2) that shows the
    rendered Mermaid diagram.
    """
    st.set_page_config(layout="wide", page_title="AI-Powered Flowchart Creator")
    st.title("AI-Powered Flowchart Creator")
    _init_session_state()
    chat_col, flowchart_col = st.columns([3, 2])
    with chat_col:
        _render_chat_column()
    with flowchart_col:
        _render_flowchart_column()


if __name__ == "__main__":
    main()
# Discussion of this gist takes place on GitHub; sign in there to comment.