Skip to content

Instantly share code, notes, and snippets.

@virattt
Last active February 16, 2025 00:18
Show Gist options
  • Save virattt/0e4c7740472177a327b61449c9af721d to your computer and use it in GitHub Desktop.
Save virattt/0e4c7740472177a327b61449c9af721d to your computer and use it in GitHub Desktop.
hedge-fund-agent-team-v1-3.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"xYivxWv2b6SW",
"iADG-Tp3b--h"
],
"authorship_tag": "ABX9TyMz1oA7xmBxOwLywCKfk/fR",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/0e4c7740472177a327b61449c9af721d/hedge-fund-agent-team-v1-3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"This notebook provides a tutorial on how to use multi-agents with LangGraph.\n",
"\n",
"Specifically, we use the **supervisor** pattern: one supervisor agent routes each request to one or more of 3 analyst agents and then summarizes their reports:\n",
"1. fundamental analyst\n",
"2. technical analyst\n",
"3. sentiment analyst\n",
"\n",
"This code is part of an evolving series.\n",
"\n",
"If you have any questions, please message me on X at [virattt](https://twitter.com/virattt)."
],
"metadata": {
"id": "Xp0Uq2g0uLxb"
}
},
{
"cell_type": "markdown",
"source": [
"# Setup"
],
"metadata": {
"id": "xYivxWv2b6SW"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "unPP2JGqble1"
},
"outputs": [],
"source": [
"%%capture --no-stderr\n",
"%pip install -U langgraph langchain langchain_openai langchain_experimental langsmith pandas ta"
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"\n",
"def _set_if_undefined(var: str):\n",
"    \"\"\"Prompt for `var` (input hidden) and store it in os.environ, unless it already has a non-empty value.\"\"\"\n",
"    if os.environ.get(var):\n",
"        return  # already configured; don't prompt again\n",
"    os.environ[var] = getpass.getpass(f\"Please provide your {var}\")\n",
"\n",
"\n",
"# Keys used in this notebook:\n",
"#   OPENAI_API_KEY             - for the agents. Get from https://platform.openai.com\n",
"#   FINANCIAL_DATASETS_API_KEY - for getting financial data. Get from https://financialdatasets.ai\n",
"#   TAVILY_API_KEY             - for surfing the web. Get from https://tavily.com\n",
"_set_if_undefined(\"OPENAI_API_KEY\")\n",
"_set_if_undefined(\"FINANCIAL_DATASETS_API_KEY\")\n",
"_set_if_undefined(\"TAVILY_API_KEY\")"
],
"metadata": {
"id": "5zJ1jU9-b9WS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Define agent tools"
],
"metadata": {
"id": "iADG-Tp3b--h"
}
},
{
"cell_type": "code",
"source": [
"from langchain_core.tools import tool\n",
"from typing import List, Dict, Optional, Union\n",
"import requests\n",
"import os\n",
"from typing import Dict, Union\n",
"from pydantic import BaseModel, Field\n",
"import requests\n",
"from langchain_core.tools import tool\n",
"\n",
"import pandas as pd\n",
"import ta\n",
"from datetime import datetime, timedelta\n",
"\n",
"FINANCIAL_DATASETS_BASE_URL = \"https://api.financialdatasets.ai\"\n",
"# requests has no default timeout; without one a stalled connection would hang the agent.\n",
"REQUEST_TIMEOUT_SECONDS = 30\n",
"\n",
"\n",
"def _fetch_financial_datasets(path: str, params: Dict, empty_result: Dict) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    GET `path` from the Financial Datasets API and return the decoded JSON body.\n",
"\n",
"    Args:\n",
"        path: endpoint path relative to FINANCIAL_DATASETS_BASE_URL.\n",
"        params: query parameters. requests URL-encodes them (so values containing\n",
"            '&' or spaces cannot corrupt the query string) and omits keys whose\n",
"            value is None.\n",
"        empty_result: tool-specific fallback payload returned on failure.\n",
"\n",
"    Raises:\n",
"        ValueError: if FINANCIAL_DATASETS_API_KEY is not set.\n",
"\n",
"    Returns:\n",
"        The JSON payload on success. On an HTTP error status, a network error\n",
"        (including the timeout) or a non-JSON body: `empty_result` plus an\n",
"        \"error\" message, so every tool reports failures in the same shape.\n",
"    \"\"\"\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    try:\n",
"        response = requests.get(\n",
"            f\"{FINANCIAL_DATASETS_BASE_URL}/{path}\",\n",
"            headers={'X-API-Key': api_key},\n",
"            params=params,\n",
"            timeout=REQUEST_TIMEOUT_SECONDS,\n",
"        )\n",
"        response.raise_for_status()\n",
"        return response.json()\n",
"    except requests.HTTPError as e:\n",
"        # Keep the API's own explanation (response body) for the agent to read.\n",
"        detail = e.response.text if e.response is not None else \"\"\n",
"        return {**empty_result, \"error\": f\"{e} {detail}\".strip()}\n",
"    except (requests.RequestException, ValueError) as e:\n",
"        # ValueError: the response body was not valid JSON.\n",
"        return {**empty_result, \"error\": str(e)}\n",
"\n",
"\n",
"# The fetch tools below use return_direct=True: the agent returns the tool's raw\n",
"# output as its answer. Their docstrings double as the tool descriptions the LLM sees.\n",
"class GetIncomeStatementsInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    period: str = Field(default=\"ttm\", description=\"The period of the income statements. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of income statements to return. Default is 10.\")\n",
"\n",
"\n",
"@tool(\"get_income_statements\", args_schema=GetIncomeStatementsInput, return_direct=True)\n",
"def get_income_statements(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get income statements for a ticker with specified period and limit.\n",
"    \"\"\"\n",
"    return _fetch_financial_datasets(\n",
"        \"financials/income-statements\",\n",
"        {\"ticker\": ticker, \"period\": period, \"limit\": limit},\n",
"        {\"ticker\": ticker, \"income_statements\": []},\n",
"    )\n",
"\n",
"\n",
"class GetBalanceSheetsInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    period: str = Field(default=\"ttm\", description=\"The period of the balance sheets. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of balance sheets to return. Default is 10.\")\n",
"\n",
"\n",
"@tool(\"get_balance_sheets\", args_schema=GetBalanceSheetsInput, return_direct=True)\n",
"def get_balance_sheets(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get balance sheets for a ticker with specified period and limit.\n",
"    \"\"\"\n",
"    return _fetch_financial_datasets(\n",
"        \"financials/balance-sheets\",\n",
"        {\"ticker\": ticker, \"period\": period, \"limit\": limit},\n",
"        {\"ticker\": ticker, \"balance_sheets\": []},\n",
"    )\n",
"\n",
"\n",
"class GetCashFlowStatementsInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    period: str = Field(default=\"ttm\", description=\"The period of the cash flow statements. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of cash flow statements to return. Default is 10.\")\n",
"\n",
"\n",
"@tool(\"get_cash_flow_statements\", args_schema=GetCashFlowStatementsInput, return_direct=True)\n",
"def get_cash_flow_statements(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get cash flow statements for a ticker with specified period and limit.\n",
"    \"\"\"\n",
"    return _fetch_financial_datasets(\n",
"        \"financials/cash-flow-statements\",\n",
"        {\"ticker\": ticker, \"period\": period, \"limit\": limit},\n",
"        {\"ticker\": ticker, \"cash_flow_statements\": []},\n",
"    )\n",
"\n",
"\n",
"class GetPricesInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    start_date: str = Field(..., description=\"The start of the price time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.\")\n",
"    end_date: str = Field(..., description=\"The end of the price time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.\")\n",
"    interval: str = Field(default=\"day\", description=\"The time interval of the prices. Valid values are 'second', 'minute', 'day', 'week', 'month', 'quarter', 'year'.\")\n",
"    interval_multiplier: int = Field(default=1, description=\"The multiplier for the interval. For example, if interval is 'day' and interval_multiplier is 1, the prices will be daily. If interval is 'minute' and interval_multiplier is 5, the prices will be every 5 minutes.\")\n",
"    limit: int = Field(default=5000, description=\"The maximum number of prices to return. The default is 5000 and the maximum is 50000.\")\n",
"\n",
"\n",
"@tool(\"get_stock_prices\", args_schema=GetPricesInput, return_direct=True)\n",
"def get_stock_prices(\n",
"    ticker: str,\n",
"    start_date: str,\n",
"    end_date: str,\n",
"    interval: str = 'day',\n",
"    interval_multiplier: int = 1,\n",
"    limit: int = 5000,\n",
") -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get prices for a ticker over a given date range and interval.\n",
"    \"\"\"\n",
"    return _fetch_financial_datasets(\n",
"        \"prices\",\n",
"        {\n",
"            \"ticker\": ticker,\n",
"            \"start_date\": start_date,\n",
"            \"end_date\": end_date,\n",
"            \"interval\": interval,\n",
"            \"interval_multiplier\": interval_multiplier,\n",
"            \"limit\": limit,\n",
"        },\n",
"        {\"ticker\": ticker, \"prices\": []},\n",
"    )\n",
"\n",
"\n",
"class GetCurrentPriceInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"\n",
"\n",
"@tool(\"get_current_stock_price\", args_schema=GetCurrentPriceInput, return_direct=True)\n",
"def get_current_stock_price(ticker: str) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get the current (latest) stock price for a ticker.\n",
"    \"\"\"\n",
"    return _fetch_financial_datasets(\n",
"        \"prices/snapshot\",\n",
"        {\"ticker\": ticker},\n",
"        {\"ticker\": ticker, \"price\": None},\n",
"    )\n",
"\n",
"\n",
"class GetOptionsChainInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of options to return. Default is 10.\")\n",
"    strike_price: Optional[float] = Field(default=None, description=\"Optional filter for specific strike price.\")\n",
"    option_type: Optional[str] = Field(default=None, description=\"Optional filter for option type. Valid values are 'call' or 'put'.\")\n",
"\n",
"\n",
"@tool(\"get_options_chain\", args_schema=GetOptionsChainInput, return_direct=True)\n",
"def get_options_chain(\n",
"    ticker: str,\n",
"    limit: int = 10,\n",
"    strike_price: Optional[float] = None,\n",
"    option_type: Optional[str] = None\n",
") -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get options chain data for a ticker with optional filters for strike price and option type.\n",
"    \"\"\"\n",
"    # requests omits None-valued params, so unset filters are simply not sent.\n",
"    return _fetch_financial_datasets(\n",
"        \"options/chain\",\n",
"        {\n",
"            \"ticker\": ticker,\n",
"            \"limit\": limit,\n",
"            \"strike_price\": strike_price,\n",
"            \"option_type\": option_type,\n",
"        },\n",
"        {\"ticker\": ticker, \"options_chain\": []},\n",
"    )\n",
"\n",
"\n",
"class GetInsiderTradesInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of insider transactions to return. Default is 10.\")\n",
"\n",
"\n",
"@tool(\"get_insider_trades\", args_schema=GetInsiderTradesInput, return_direct=True)\n",
"def get_insider_trades(ticker: str, limit: int = 10) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get insider trading transactions for a ticker.\n",
"    \"\"\"\n",
"    return _fetch_financial_datasets(\n",
"        \"insider-transactions\",\n",
"        {\"ticker\": ticker, \"limit\": limit},\n",
"        {\"ticker\": ticker, \"insider_transactions\": []},\n",
"    )\n",
"\n",
"# Indicators get_technical_indicators can compute.\n",
"SUPPORTED_INDICATORS = (\"rsi\", \"macd\", \"sma\", \"ema\", \"bbands\")\n",
"# MACD uses the conventional fixed windows; `period` does not apply to it.\n",
"MACD_FAST, MACD_SLOW, MACD_SIGNAL = 12, 26, 9\n",
"# Window used when the agent omits start_date (counted back from end_date).\n",
"DEFAULT_INDICATOR_LOOKBACK_DAYS = 180\n",
"\n",
"\n",
"class GetTechnicalIndicatorsInput(BaseModel):\n",
"    \"\"\"Input schema for technical indicators calculation.\"\"\"\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    indicator: str = Field(..., description=\"The technical indicator to calculate. Valid values are 'rsi', 'macd', 'sma', 'ema', 'bbands'.\")\n",
"    period: Optional[int] = Field(default=14, description=\"The period for the indicator calculation. Default is 14.\")\n",
"    start_date: Optional[str] = Field(default=None, description=\"Start date in YYYY-MM-DD format. Defaults to 180 days before end_date.\")\n",
"    end_date: Optional[str] = Field(default=None, description=\"End date in YYYY-MM-DD format. Defaults to today.\")\n",
"    interval: Optional[str] = Field(default=\"day\", description=\"The time interval for price data.\")\n",
"    interval_multiplier: Optional[int] = Field(default=1, description=\"Multiplier for the time interval.\")\n",
"\n",
"\n",
"@tool(\"get_technical_indicators\", args_schema=GetTechnicalIndicatorsInput)\n",
"def get_technical_indicators(\n",
"    ticker: str,\n",
"    indicator: str,\n",
"    period: int = 14,\n",
"    interval: str = \"day\",\n",
"    interval_multiplier: int = 1,\n",
"    start_date: Optional[str] = None,\n",
"    end_date: Optional[str] = None,\n",
") -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Calculate technical indicators for a given ticker and time period.\n",
"    Supports RSI, MACD, SMA, EMA, and Bollinger Bands calculations.\n",
"    \"\"\"\n",
"    # Returns {\"ticker\", \"indicator\", \"period\", \"data\": [{\"time\", \"time_milliseconds\", \"value\", ...}]}.\n",
"    # MACD points add \"signal_line\"/\"histogram\" (\"value\" is the MACD line); Bollinger points add\n",
"    # \"upper_band\"/\"lower_band\" (\"value\" is the middle band). Failures are returned, not raised,\n",
"    # as {\"ticker\", \"indicator\", \"error\"} so the agent can read them and adjust its call.\n",
"    name = (indicator or \"\").lower()\n",
"    try:\n",
"        if name not in SUPPORTED_INDICATORS:\n",
"            raise ValueError(f\"unsupported indicator; valid values are {', '.join(SUPPORTED_INDICATORS)}\")\n",
"\n",
"        # The schema marks these Optional, so the agent may send explicit nulls.\n",
"        period = period or 14\n",
"        interval = interval or \"day\"\n",
"        interval_multiplier = interval_multiplier or 1\n",
"        if period < 1:\n",
"            raise ValueError(\"period must be a positive integer\")\n",
"\n",
"        # Both dates are optional: default to the last DEFAULT_INDICATOR_LOOKBACK_DAYS up to today.\n",
"        end_dt = datetime.strptime(end_date, \"%Y-%m-%d\") if end_date else datetime.now()\n",
"        start_dt = (\n",
"            datetime.strptime(start_date, \"%Y-%m-%d\")\n",
"            if start_date\n",
"            else end_dt - timedelta(days=DEFAULT_INDICATOR_LOOKBACK_DAYS)\n",
"        )\n",
"        if start_dt > end_dt:\n",
"            raise ValueError(\"start_date must not be after end_date\")\n",
"        start_date = start_dt.strftime(\"%Y-%m-%d\")\n",
"        end_date = end_dt.strftime(\"%Y-%m-%d\")\n",
"\n",
"        # Fetch extra history before start_date so the indicator is warmed up by the first\n",
"        # requested bar. Bars needed = longest look-back window (MACD needs slow + signal).\n",
"        # x2 converts bars to calendar days with slack for weekends/holidays; this is sized\n",
"        # for daily bars, so coarser intervals may simply yield fewer points (warm-up rows\n",
"        # are dropped below rather than filled).\n",
"        warmup_bars = MACD_SLOW + MACD_SIGNAL if name == \"macd\" else period\n",
"        padded_start = (start_dt - timedelta(days=warmup_bars * 2)).strftime(\"%Y-%m-%d\")\n",
"\n",
"        price_data = get_stock_prices.invoke({\n",
"            \"ticker\": ticker,\n",
"            \"start_date\": padded_start,\n",
"            \"end_date\": end_date,\n",
"            \"interval\": interval,\n",
"            \"interval_multiplier\": interval_multiplier,\n",
"        })\n",
"        if \"error\" in price_data:\n",
"            return price_data\n",
"        prices = price_data.get(\"prices\") or []\n",
"        if not prices:\n",
"            raise ValueError(f\"no price data returned between {padded_start} and {end_date}\")\n",
"\n",
"        df = pd.DataFrame(prices)\n",
"        # Strip the trailing ' EDT'/' EST' zone abbreviation before parsing; times are kept\n",
"        # as naive wall-clock values. Sorting makes the date-range slice below reliable.\n",
"        df[\"time\"] = pd.to_datetime(df[\"time\"].astype(str).str.replace(r\" E[DS]T.*$\", \"\", regex=True))\n",
"        df = df.set_index(\"time\").sort_index()\n",
"\n",
"        close = df[\"close\"]\n",
"        if name == \"rsi\":\n",
"            df[\"indicator_value\"] = ta.momentum.RSIIndicator(close, window=period).rsi()\n",
"        elif name == \"macd\":\n",
"            macd = ta.trend.MACD(close, window_slow=MACD_SLOW, window_fast=MACD_FAST, window_sign=MACD_SIGNAL)\n",
"            df[\"indicator_value\"] = macd.macd()\n",
"            df[\"signal_line\"] = macd.macd_signal()\n",
"            df[\"histogram\"] = macd.macd_diff()\n",
"        elif name == \"sma\":\n",
"            df[\"indicator_value\"] = ta.trend.SMAIndicator(close, window=period).sma_indicator()\n",
"        elif name == \"ema\":\n",
"            df[\"indicator_value\"] = ta.trend.EMAIndicator(close, window=period).ema_indicator()\n",
"        else:  # \"bbands\"\n",
"            bands = ta.volatility.BollingerBands(close, window=period, window_dev=2)\n",
"            df[\"indicator_value\"] = bands.bollinger_mavg()\n",
"            df[\"upper_band\"] = bands.bollinger_hband()\n",
"            df[\"lower_band\"] = bands.bollinger_lband()\n",
"\n",
"        extra_columns = {\n",
"            \"macd\": [\"signal_line\", \"histogram\"],\n",
"            \"bbands\": [\"upper_band\", \"lower_band\"],\n",
"        }.get(name, [])\n",
"\n",
"        # Keep the requested window and drop bars where the indicator is still warming\n",
"        # up (NaN); back-filling them would report values that were never computed.\n",
"        df = df.loc[start_date:end_date].dropna(subset=[\"indicator_value\", *extra_columns])\n",
"\n",
"        data = []\n",
"        for ts, row in df.iterrows():\n",
"            point = {\n",
"                \"time\": ts.strftime(\"%Y-%m-%d %H:%M:%S\"),\n",
"                \"time_milliseconds\": int(ts.timestamp() * 1000),\n",
"                \"value\": float(row[\"indicator_value\"]),\n",
"            }\n",
"            point.update({col: float(row[col]) for col in extra_columns})\n",
"            data.append(point)\n",
"\n",
"        return {\"ticker\": ticker, \"indicator\": indicator, \"period\": period, \"data\": data}\n",
"\n",
"    except Exception as e:\n",
"        # Best-effort tool: report the failure to the agent instead of raising.\n",
"        return {\n",
"            \"ticker\": ticker,\n",
"            \"indicator\": indicator,\n",
"            \"error\": f\"Error calculating {indicator}: {str(e)}\"\n",
"        }"
],
"metadata": {
"id": "twLVNqHMb_w9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from typing import Annotated\n",
"\n",
"from langchain_community.tools.tavily_search import TavilySearchResults\n",
"\n",
"# News tool for the sentiment analyst: Tavily web search, at most 5 results per query.\n",
"get_news_tool = TavilySearchResults(max_results=5)"
],
"metadata": {
"id": "OQ650p7nM6ad"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Group tools by analyst: each ReAct agent is given only its own tool set.\n",
"fundamental_tools = [\n",
"    get_income_statements,\n",
"    get_balance_sheets,\n",
"    get_cash_flow_statements,\n",
"]\n",
"technical_tools = [\n",
"    get_stock_prices,\n",
"    get_current_stock_price,\n",
"    get_technical_indicators,\n",
"]\n",
"sentiment_tools = [\n",
"    get_options_chain,\n",
"    get_insider_trades,\n",
"    get_news_tool,\n",
"]"
],
"metadata": {
"id": "ntBR8eulNR72"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Helper functions"
],
"metadata": {
"id": "F0N8OtbbcG0L"
}
},
{
"cell_type": "code",
"source": [
"from langchain_core.messages import HumanMessage\n",
"\n",
"def agent_node(state, agent, name):\n",
"    \"\"\"\n",
"    Run `agent` on the shared graph state and return its final reply as a message named `name`.\n",
"\n",
"    Also advances `current_analyst_idx` so the router moves on to the next selected\n",
"    analyst; without that, routing would return to this same analyst until the\n",
"    recursion limit is hit. The \"Create LangGraph\" cell redefines this function\n",
"    before the graph is built.\n",
"    \"\"\"\n",
"    result = agent.invoke(state)\n",
"    return {\n",
"        \"messages\": [HumanMessage(content=result[\"messages\"][-1].content, name=name)],\n",
"        \"selected_analysts\": state.get(\"selected_analysts\", []),\n",
"        \"current_analyst_idx\": state.get(\"current_analyst_idx\", 0) + 1,\n",
"    }"
],
"metadata": {
"id": "I7LhGDuEcCpC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Create LangGraph"
],
"metadata": {
"id": "v5LtjrI-cJ1P"
}
},
{
"cell_type": "code",
"source": [
"from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain_openai import ChatOpenAI\n",
"from pydantic import BaseModel\n",
"from typing import Literal, Sequence, List\n",
"from typing_extensions import TypedDict\n",
"import functools\n",
"import operator\n",
"from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage\n",
"from langgraph.graph import END, StateGraph, START\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"# Analyst node names; each gets its own graph node below.\n",
"members = [\"fundamental_analyst\", \"technical_analyst\", \"sentiment_analyst\"]\n",
"\n",
"# Routing system prompt, one entry per line of the final prompt.\n",
"_ROUTING_SYSTEM_LINES = [\n",
"    \"You are a portfolio manager supervising a hedge fund team with the following analysts:\",\n",
"    \"- fundamental_analyst: Analyzes financial statements and company health\",\n",
"    \"- technical_analyst: Analyzes price patterns and market trends\",\n",
"    \"- sentiment_analyst: Analyzes insider trading activity, options flow, and the news\",\n",
"    \"Determine which analyst(s) should analyze the request. Respond with ONLY the analyst names\"\n",
"    \" separated by commas (e.g., 'technical_analyst,sentiment_analyst'). Choose analysts based on:\",\n",
"    \"- Use fundamental_analyst for questions about financials, valuations, or company health\",\n",
"    \"- Use technical_analyst for questions about price action, trends, or chart patterns\",\n",
"    \"- Use sentiment_analyst for questions about market sentiment, news impact, or trading activity\",\n",
"]\n",
"routing_prompt = ChatPromptTemplate.from_messages([\n",
"    (\"system\", \"\\n\".join(_ROUTING_SYSTEM_LINES)),\n",
"    MessagesPlaceholder(variable_name=\"messages\"),\n",
"])\n",
"\n",
"# Summary system prompt, one entry per line (items 1-3 keep their trailing space).\n",
"_SUMMARY_SYSTEM_LINES = [\n",
"    \"You are a portfolio manager responsible for synthesizing analysis from your team of analysts. \"\n",
"    \"Review all the analysts' reports and provide a comprehensive summary including:\",\n",
"    \"1. Key financial metrics and their implications (only when you have this data) \",\n",
"    \"2. Technical analysis insights (only when you have this data) \",\n",
"    \"3. Market sentiment and news impact (only when you have this data) \",\n",
"    \"4. Overall investment recommendation\",\n",
"    \"Make sure to highlight any discrepancies or conflicting signals between different analyses.\",\n",
"]\n",
"summary_prompt = ChatPromptTemplate.from_messages([\n",
"    (\"system\", \"\\n\".join(_SUMMARY_SYSTEM_LINES)),\n",
"    MessagesPlaceholder(variable_name=\"messages\"),\n",
"    (\n",
"        \"human\",\n",
"        \"Based on all the analyst reports above, provide a comprehensive summary and investment recommendation.\",\n",
"    ),\n",
"])\n",
"\n",
"# LLM shared by the supervisor, the analyst agents and the final summary.\n",
"llm = ChatOpenAI(model=\"gpt-4\")\n",
"\n",
"# Annotated is also imported in the news-tool cell; imported here too so this definition doesn't depend on that cell.\n",
"from typing import Annotated\n",
"\n",
"\n",
"class AgentState(TypedDict):\n",
"    \"\"\"State shared by all nodes of the hedge-fund graph.\"\"\"\n",
"\n",
"    # Conversation so far. operator.add is the reducer: a node's returned\n",
"    # messages are appended, so nodes must return only their new messages.\n",
"    messages: Annotated[Sequence[BaseMessage], operator.add]\n",
"    # Analyst node names picked by the supervisor, in the order they run.\n",
"    selected_analysts: List[str]\n",
"    # Position in selected_analysts of the next analyst to run.\n",
"    current_analyst_idx: int\n",
"\n",
"def supervisor_router(state):\n",
"    \"\"\"\n",
"    Route the query to the appropriate analyst(s) and reset the analyst cursor.\n",
"\n",
"    The routing LLM replies with comma-separated analyst names. Returns a state update:\n",
"      - messages: only the new routing message. `messages` is reduced with operator.add,\n",
"        so returning state[\"messages\"] + [...] would append the whole history again and\n",
"        duplicate every earlier message.\n",
"      - selected_analysts: names from `members`, in the order given, without duplicates.\n",
"        Anything else is dropped, because the conditional edges only map the analysts\n",
"        and \"final_summary\"; an empty list routes straight to the final summary.\n",
"      - current_analyst_idx: 0\n",
"    \"\"\"\n",
"    routing_chain = routing_prompt | llm\n",
"    result = routing_chain.invoke(state)\n",
"\n",
"    selected_analysts = []\n",
"    # Accept newline- as well as comma-separated replies.\n",
"    for raw_name in result.content.replace(\"\\n\", \",\").split(\",\"):\n",
"        # Strip whitespace, plus any quotes/backticks/periods the model may wrap names in.\n",
"        name = raw_name.strip().strip(\"'\\\"`.\").strip()\n",
"        if name in members and name not in selected_analysts:\n",
"            selected_analysts.append(name)\n",
"\n",
"    message = SystemMessage(\n",
"        content=f\"Routing query to: {', '.join(selected_analysts) or 'final_summary'}\",\n",
"        name=\"supervisor\",\n",
"    )\n",
"\n",
"    return {\n",
"        \"messages\": [message],\n",
"        \"selected_analysts\": selected_analysts,\n",
"        \"current_analyst_idx\": 0,\n",
"    }\n",
"\n",
"def get_next_step(state):\n",
" \"\"\"Determine the next step in the workflow\"\"\"\n",
" if not state[\"selected_analysts\"]:\n",
" return \"final_summary\"\n",
"\n",
" current_idx = state[\"current_analyst_idx\"]\n",
" if current_idx >= len(state[\"selected_analysts\"]):\n",
" return \"final_summary\"\n",
"\n",
" return state[\"selected_analysts\"][current_idx]\n",
"\n",
"def agent_node(state, agent, name):\n",
"    \"\"\"Run one analyst agent; publish its last message under `name` and advance the analyst cursor.\"\"\"\n",
"    final_message = agent.invoke(state)[\"messages\"][-1]\n",
"    update = {\n",
"        \"messages\": [HumanMessage(content=final_message.content, name=name)],\n",
"        # Routing fields: keep the plan, move on to the next selected analyst.\n",
"        \"selected_analysts\": state[\"selected_analysts\"],\n",
"        \"current_analyst_idx\": state[\"current_analyst_idx\"] + 1,\n",
"    }\n",
"    return update\n",
"\n",
"def final_summary_agent(state):\n",
"    \"\"\"Have the portfolio manager LLM synthesize every analyst report into one recommendation.\"\"\"\n",
"    summary = (summary_prompt | llm).invoke(state)\n",
"    report = HumanMessage(content=summary.content, name=\"portfolio_manager\")\n",
"    return {\n",
"        \"messages\": [report],\n",
"        # Routing fields are passed through unchanged.\n",
"        \"selected_analysts\": state[\"selected_analysts\"],\n",
"        \"current_analyst_idx\": state[\"current_analyst_idx\"],\n",
"    }\n",
"\n",
"# Build the graph: START -> supervisor -> selected analysts (one at a time) -> final_summary -> END\n",
"workflow = StateGraph(AgentState)\n",
"\n",
"# One ReAct agent per analyst, each restricted to its own tool set\n",
"fundamental_analyst = create_react_agent(llm, tools=fundamental_tools)\n",
"technical_analyst = create_react_agent(llm, tools=technical_tools)\n",
"sentiment_analyst = create_react_agent(llm, tools=sentiment_tools)\n",
"\n",
"# Wrap each agent as a graph node that reports under the analyst's name\n",
"fundamental_analyst_node = functools.partial(agent_node, agent=fundamental_analyst, name=\"fundamental_analyst\")\n",
"technical_analyst_node = functools.partial(agent_node, agent=technical_analyst, name=\"technical_analyst\")\n",
"sentiment_analyst_node = functools.partial(agent_node, agent=sentiment_analyst, name=\"sentiment_analyst\")\n",
"\n",
"workflow.add_node(\"supervisor\", supervisor_router)\n",
"workflow.add_node(\"fundamental_analyst\", fundamental_analyst_node)\n",
"workflow.add_node(\"technical_analyst\", technical_analyst_node)\n",
"workflow.add_node(\"sentiment_analyst\", sentiment_analyst_node)\n",
"workflow.add_node(\"final_summary\", final_summary_agent)\n",
"\n",
"# get_next_step returns an analyst name or \"final_summary\"; each maps to the node of the same name.\n",
"# The supervisor and every analyst share this router, so the selected analysts run back to back.\n",
"route_targets = {target: target for target in [*members, \"final_summary\"]}\n",
"workflow.add_conditional_edges(\"supervisor\", get_next_step, dict(route_targets))\n",
"for analyst in members:\n",
"    workflow.add_conditional_edges(analyst, get_next_step, dict(route_targets))\n",
"\n",
"workflow.add_edge(START, \"supervisor\")\n",
"workflow.add_edge(\"final_summary\", END)\n",
"\n",
"graph = workflow.compile()"
],
"metadata": {
"id": "TT2AggDicQt6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"# Example usage: one full run, then print the portfolio manager's final summary.\n",
"# Only `messages` is needed as input; the supervisor fills in the routing fields.\n",
"response = graph.invoke(\n",
"    {\"messages\": [HumanMessage(content=\"Analyze AAPL's recent price action and market sentiment\")]},\n",
"    {\"recursion_limit\": 10},\n",
")\n",
"print(response[\"messages\"][-1].content)"
],
"metadata": {
"id": "_YI-F-ZM9iOQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Pretty-print agent output"
],
"metadata": {
"id": "V7_AEtWz56-n"
}
},
{
"cell_type": "code",
"source": [
"from typing import Dict, Any\n",
"import json\n",
"import re\n",
"from langchain_core.messages import HumanMessage\n",
"from rich.console import Console\n",
"from rich.panel import Panel\n",
"from rich.text import Text\n",
"from rich.rule import Rule\n",
"\n",
"console = Console()\n",
"\n",
"def format_bold_text(content: str) -> Text:\n",
"    \"\"\"Render Markdown-style **bold** spans of `content` as a rich Text; all other text is unstyled.\"\"\"\n",
"    # With one capture group, re.split yields [plain, bold, plain, bold, ..., plain],\n",
"    # so every odd position holds the inside of a **...** pair.\n",
"    pieces = re.split(r'\\*\\*(.*?)\\*\\*', content)\n",
"    rendered = Text()\n",
"    for position, piece in enumerate(pieces):\n",
"        rendered.append(piece, style=\"bold\" if position % 2 else None)\n",
"    return rendered\n",
"\n",
"def format_message_content(content: str) -> Union[str, Text]:\n",
"    \"\"\"\n",
"    Format message content for display.\n",
"\n",
"    - Valid JSON (e.g. raw tool output, since several tools use return_direct=True)\n",
"      is pretty-printed with 2-space indentation.\n",
"    - Otherwise, text containing **bold** markers becomes a rich Text with bold spans.\n",
"    - Anything else is returned unchanged. Non-string content (e.g. a list of\n",
"      content parts) is converted with str() first so rich can render it.\n",
"    \"\"\"\n",
"    if not isinstance(content, str):\n",
"        content = str(content)\n",
"    try:\n",
"        return json.dumps(json.loads(content), indent=2)\n",
"    except json.JSONDecodeError:\n",
"        pass  # not JSON; fall through to text formatting\n",
"    if '**' in content:\n",
"        return format_bold_text(content)\n",
"    return content\n",
"\n",
"def format_agent_message(message: HumanMessage) -> Union[str, Text]:\n",
"    \"\"\"Return the display form of one agent message, based on its content only.\"\"\"\n",
"    content = message.content\n",
"    return format_message_content(content)\n",
"\n",
"def get_agent_title(agent: str, message: HumanMessage) -> str:\n",
"    \"\"\"\n",
"    Panel title for a message, e.g. \"technical_analyst\" -> \"Technical Analyst\".\n",
"\n",
"    Uses the message's `name` when it is a non-empty string, otherwise the graph\n",
"    node name `agent`; the isinstance check replaces a bare `except:`.\n",
"    \"\"\"\n",
"    name = getattr(message, 'name', None)\n",
"    source = name if isinstance(name, str) and name else agent\n",
"    return source.replace('_', ' ').title()\n",
"\n",
"def _routing_panel(targets) -> Panel:\n",
"    \"\"\"Build the 'Supervision Step' panel listing the node(s) the supervisor routes to.\"\"\"\n",
"    text = Text()\n",
"    text.append(\"Portfolio Manager \", style=\"bold magenta\")\n",
"    text.append(\"assigns next task to \", style=\"white\")\n",
"    # With no analysts selected, the graph goes straight to the final summary.\n",
"    for i, target in enumerate(targets or [\"final_summary\"]):\n",
"        if i:\n",
"            text.append(\", \", style=\"white\")\n",
"        if target == \"final_summary\":\n",
"            text.append(\"FINAL SUMMARY\", style=\"bold yellow\")\n",
"        elif target == \"END\":\n",
"            text.append(\"END\", style=\"bold red\")\n",
"        else:\n",
"            text.append(f\"{target}\", style=\"bold green\")\n",
"    return Panel(text, title=\"[bold blue]Supervision Step\", border_style=\"blue\")\n",
"\n",
"def print_step(step: Dict[str, Any]) -> None:\n",
"    \"\"\"\n",
"    Pretty print one streamed graph step, a mapping of {node name: state update}.\n",
"\n",
"    - supervisor: a routing panel naming the selected analysts. The supervisor node\n",
"      returns `selected_analysts`; a single-target `next` key is also accepted.\n",
"      Its routing message itself is not printed.\n",
"    - final_summary: the portfolio manager's summary between yellow rules.\n",
"    - any other node: its latest message as an analyst report.\n",
"    \"\"\"\n",
"    for agent, data in step.items():\n",
"        if not data:\n",
"            continue  # node produced no state update\n",
"\n",
"        if agent == \"supervisor\" or 'next' in data:\n",
"            targets = data.get(\"selected_analysts\")\n",
"            if targets is None:\n",
"                targets = [data['next']] if 'next' in data else []\n",
"            console.print(_routing_panel(targets))\n",
"            continue\n",
"\n",
"        messages = data.get('messages')\n",
"        if not messages:\n",
"            continue\n",
"        message = messages[-1]\n",
"        formatted_content = format_agent_message(message)\n",
"\n",
"        if agent == \"final_summary\":\n",
"            console.print(Rule(style=\"yellow\", title=\"Portfolio Analysis\"))\n",
"            console.print(Panel(\n",
"                formatted_content,\n",
"                title=\"[bold yellow]Investment Summary and Recommendation\",\n",
"                border_style=\"yellow\",\n",
"                padding=(1, 2)\n",
"            ))\n",
"            console.print(Rule(style=\"yellow\"))\n",
"        else:\n",
"            title = get_agent_title(agent, message)\n",
"            console.print(Panel(\n",
"                formatted_content,\n",
"                title=f\"[bold blue]{title} Report\",\n",
"                border_style=\"green\"\n",
"            ))\n",
"\n",
"def stream_agent_execution(graph, input_data: Dict, config: Dict) -> None:\n",
"    \"\"\"Run `graph` on `input_data` with `config` (e.g. a recursion limit), pretty-printing each step as it streams.\"\"\"\n",
"    console.print(\"\\n[bold blue]Starting Agent Execution...[/bold blue]\\n\")\n",
"    for step in graph.stream(input_data, config):\n",
"        if \"__end__\" in step:\n",
"            continue  # skip any terminal \"__end__\" update\n",
"        print_step(step)\n",
"        console.print(\"\\n\")\n",
"    console.print(\"[bold blue]Analysis Complete[/bold blue]\\n\")"
],
"metadata": {
"id": "t2E2mnnJ5LaN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Run the Hedge Fund team"
],
"metadata": {
"id": "Y1_IZnAUTAHw"
}
},
{
"cell_type": "code",
"source": [
"config = {\"recursion_limit\": 10}  # max graph steps before LangGraph gives up on the run\n",
"input_data = {\"messages\": [HumanMessage(content=\"What is AAPL's current price and latest revenue?\")]}\n",
"stream_agent_execution(graph, input_data, config)"
],
"metadata": {
"id": "gLUCOhL85Lip"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"config = {\"recursion_limit\": 10}  # same step budget as the run above\n",
"input_data = {\"messages\": [HumanMessage(content=\"What is AAPL's latest news?\")]}\n",
"stream_agent_execution(graph, input_data, config)"
],
"metadata": {
"id": "pYyfbkCLNmGD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "6BFJLGik_oRu"
},
"execution_count": null,
"outputs": []
}
]
}
@mger1608
Copy link

Are you running this locally or is Google Colab sufficient to manage all of this? I'm not technical by background but have been trying to build more and more in this realm. I made some edits via a cloned notebook and then tried to start running locally but am running into issues managing packages.

Are you using a package manager if running locally or keeping everything in Google Colab?

@virattt
Copy link
Author

virattt commented Nov 13, 2024

@mger1608 I’m running everything in Colab!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment