Skip to content

Instantly share code, notes, and snippets.

@virattt
Created November 19, 2024 23:51
Show Gist options
  • Save virattt/9b4b792329f6a1dfd37f1758c979c908 to your computer and use it in GitHub Desktop.
Save virattt/9b4b792329f6a1dfd37f1758c979c908 to your computer and use it in GitHub Desktop.
hedge-fund-agent-team-v1-4.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyMxDOZeeIz9VQ497kQOFQVG",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/9b4b792329f6a1dfd37f1758c979c908/hedge-fund-agent-team-v1-4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"This notebook provides a tutorial on how to build a multi-agent system with LangGraph.\n",
"\n",
"Specifically, we use the **supervisor** pattern, where 1 supervisor agent (the portfolio manager) routes each question to one or more of 4 analyst agents and then summarizes their reports:\n",
"1. quant strategist\n",
"2. macro analyst\n",
"3. event-driven analyst\n",
"4. derivative analyst\n",
"\n",
"This code will be a part of an evolving series.\n",
"\n",
"If you have any questions, please message me on X at [virattt](https://twitter.com/virattt)."
],
"metadata": {
"id": "Xp0Uq2g0uLxb"
}
},
{
"cell_type": "markdown",
"source": [
"# Setup"
],
"metadata": {
"id": "xYivxWv2b6SW"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "unPP2JGqble1"
},
"outputs": [],
"source": [
"%%capture --no-stderr\n",
"# langchain_community (TavilySearchResults) and rich (output formatting) are imported directly below, so install them explicitly.\n",
"%pip install -U langgraph langchain langchain_openai langchain_community langchain_experimental langsmith pandas ta rich"
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"\n",
"def _set_if_undefined(var: str) -> None:\n",
"    \"\"\"Prompt (with hidden input) for environment variable `var` if it is unset or empty.\"\"\"\n",
"    if not os.environ.get(var):\n",
"        os.environ[var] = getpass.getpass(f\"Please provide your {var}\")\n",
"\n",
"\n",
"_set_if_undefined(\"OPENAI_API_KEY\")  # For the agent. Get from https://platform.openai.com\n",
"_set_if_undefined(\"FINANCIAL_DATASETS_API_KEY\")  # For getting financial data. Get from https://financialdatasets.ai\n",
"_set_if_undefined(\"TAVILY_API_KEY\")  # For surfing the web. Get from https://tavily.com"
],
"metadata": {
"id": "5zJ1jU9-b9WS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Define agent tools"
],
"metadata": {
"id": "iADG-Tp3b--h"
}
},
{
"cell_type": "code",
"source": [
"from langchain_core.tools import tool\n",
"from typing import List, Dict, Optional, Union\n",
"import requests\n",
"import os\n",
"from typing import Dict, Union\n",
"from pydantic import BaseModel, Field\n",
"import requests\n",
"from langchain_core.tools import tool\n",
"\n",
"import pandas as pd\n",
"import ta\n",
"from datetime import datetime, timedelta\n",
"\n",
"class GetIncomeStatementsInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    period: str = Field(default=\"ttm\", description=\"The period of the income statements. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of income statements to return. Default is 10.\")\n",
"\n",
"@tool(\"get_income_statements\", args_schema=GetIncomeStatementsInput, return_direct=True)\n",
"def get_income_statements(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get income statements for a ticker with specified period and limit.\n",
"    \"\"\"\n",
"    # Returns the API's JSON payload as-is, or {\"ticker\", \"income_statements\": [], \"error\"}\n",
"    # if the request or JSON decoding fails. Raises ValueError when the API key is missing.\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    url = 'https://api.financialdatasets.ai/financials/income-statements'\n",
"    # Let requests URL-encode the query string instead of concatenating raw values.\n",
"    params = {'ticker': ticker, 'period': period, 'limit': limit}\n",
"\n",
"    try:\n",
"        # requests has no default timeout; bound the call so a stalled connection cannot hang the agent.\n",
"        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)\n",
"        return response.json()\n",
"    except Exception as e:\n",
"        return {\"ticker\": ticker, \"income_statements\": [], \"error\": str(e)}\n",
"\n",
"class GetBalanceSheetsInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    period: str = Field(default=\"ttm\", description=\"The period of the balance sheets. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of balance sheets to return. Default is 10.\")\n",
"\n",
"@tool(\"get_balance_sheets\", args_schema=GetBalanceSheetsInput, return_direct=True)\n",
"def get_balance_sheets(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get balance sheets for a ticker with specified period and limit.\n",
"    \"\"\"\n",
"    # Returns the API's JSON payload as-is, or {\"ticker\", \"balance_sheets\": [], \"error\"}\n",
"    # if the request or JSON decoding fails. Raises ValueError when the API key is missing.\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    url = 'https://api.financialdatasets.ai/financials/balance-sheets'\n",
"    # Let requests URL-encode the query string instead of concatenating raw values.\n",
"    params = {'ticker': ticker, 'period': period, 'limit': limit}\n",
"\n",
"    try:\n",
"        # requests has no default timeout; bound the call so a stalled connection cannot hang the agent.\n",
"        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)\n",
"        return response.json()\n",
"    except Exception as e:\n",
"        return {\"ticker\": ticker, \"balance_sheets\": [], \"error\": str(e)}\n",
"\n",
"class GetCashFlowStatementsInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    period: str = Field(default=\"ttm\", description=\"The period of the cash flow statements. Valid values are 'ttm', 'quarterly' or 'annual'.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of cash flow statements to return. Default is 10.\")\n",
"\n",
"@tool(\"get_cash_flow_statements\", args_schema=GetCashFlowStatementsInput, return_direct=True)\n",
"def get_cash_flow_statements(ticker: str, period: str = \"ttm\", limit: int = 10) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get cash flow statements for a ticker with specified period and limit.\n",
"    \"\"\"\n",
"    # Returns the API's JSON payload as-is, or {\"ticker\", \"cash_flow_statements\": [], \"error\"}\n",
"    # if the request or JSON decoding fails. Raises ValueError when the API key is missing.\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    url = 'https://api.financialdatasets.ai/financials/cash-flow-statements'\n",
"    # Let requests URL-encode the query string instead of concatenating raw values.\n",
"    params = {'ticker': ticker, 'period': period, 'limit': limit}\n",
"\n",
"    try:\n",
"        # requests has no default timeout; bound the call so a stalled connection cannot hang the agent.\n",
"        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)\n",
"        return response.json()\n",
"    except Exception as e:\n",
"        return {\"ticker\": ticker, \"cash_flow_statements\": [], \"error\": str(e)}\n",
"\n",
"class GetPricesInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    start_date: str = Field(..., description=\"The start of the price time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.\")\n",
"    end_date: str = Field(..., description=\"The end of the price time window. Either a date with the format YYYY-MM-DD or a millisecond timestamp.\")\n",
"    interval: str = Field(default=\"day\", description=\"The time interval of the prices. Valid values are 'second', 'minute', 'day', 'week', 'month', 'quarter', 'year'.\")\n",
"    interval_multiplier: int = Field(default=1, description=\"The multiplier for the interval. For example, if interval is 'day' and interval_multiplier is 1, the prices will be daily. If interval is 'minute' and interval_multiplier is 5, the prices will be every 5 minutes.\")\n",
"    limit: int = Field(default=5000, description=\"The maximum number of prices to return. The default is 5000 and the maximum is 50000.\")\n",
"\n",
"@tool(\"get_stock_prices\", args_schema=GetPricesInput, return_direct=True)\n",
"def get_stock_prices(ticker: str, start_date: str, end_date: str, interval: str = 'day', interval_multiplier: int = 1, limit: int = 5000) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get prices for a ticker over a given date range and interval.\n",
"    \"\"\"\n",
"    # Returns the API's JSON payload as-is, or {\"ticker\", \"prices\": [], \"error\"} if the\n",
"    # request or JSON decoding fails. Raises ValueError when the API key is missing.\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    url = \"https://api.financialdatasets.ai/prices\"\n",
"    # Let requests URL-encode the query string instead of concatenating raw values.\n",
"    params = {\n",
"        \"ticker\": ticker,\n",
"        \"start_date\": start_date,\n",
"        \"end_date\": end_date,\n",
"        \"interval\": interval,\n",
"        \"interval_multiplier\": interval_multiplier,\n",
"        \"limit\": limit,\n",
"    }\n",
"\n",
"    try:\n",
"        # requests has no default timeout; bound the call so a stalled connection cannot hang the agent.\n",
"        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)\n",
"        return response.json()\n",
"    except Exception as e:\n",
"        return {\"ticker\": ticker, \"prices\": [], \"error\": str(e)}\n",
"\n",
"class GetCurrentPriceInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"\n",
"@tool(\"get_current_stock_price\", args_schema=GetCurrentPriceInput, return_direct=True)\n",
"def get_current_stock_price(ticker: str) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get the current (latest) stock price for a ticker.\n",
"    \"\"\"\n",
"    # Returns the API's snapshot JSON as-is, or {\"ticker\", \"price\": None, \"error\"} if the\n",
"    # request or JSON decoding fails. Raises ValueError when the API key is missing.\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    url = \"https://api.financialdatasets.ai/prices/snapshot\"\n",
"\n",
"    try:\n",
"        # params= URL-encodes the ticker; timeout= stops a stalled connection from hanging the agent.\n",
"        response = requests.get(url, headers={'X-API-Key': api_key}, params={'ticker': ticker}, timeout=30)\n",
"        return response.json()\n",
"    except Exception as e:\n",
"        return {\"ticker\": ticker, \"price\": None, \"error\": str(e)}\n",
"\n",
"class GetOptionsChainInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of options to return. Default is 10.\")\n",
"    strike_price: Optional[float] = Field(default=None, description=\"Optional filter for specific strike price.\")\n",
"    option_type: Optional[str] = Field(default=None, description=\"Optional filter for option type. Valid values are 'call' or 'put'.\")\n",
"\n",
"@tool(\"get_options_chain\", args_schema=GetOptionsChainInput, return_direct=True)\n",
"def get_options_chain(\n",
"    ticker: str,\n",
"    limit: int = 10,\n",
"    strike_price: Optional[float] = None,\n",
"    option_type: Optional[str] = None\n",
") -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get options chain data for a ticker with optional filters for strike price and option type.\n",
"    \"\"\"\n",
"    # Returns the API's JSON payload as-is, or {\"ticker\", \"options_chain\": [], \"error\"} if the\n",
"    # request or JSON decoding fails. Raises ValueError when the API key is missing.\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    params = {\n",
"        'ticker': ticker,\n",
"        'limit': limit\n",
"    }\n",
"\n",
"    # Only send the optional filters that were actually provided.\n",
"    if strike_price is not None:\n",
"        params['strike_price'] = strike_price\n",
"    if option_type is not None:\n",
"        params['option_type'] = option_type\n",
"\n",
"    url = 'https://api.financialdatasets.ai/options/chain'\n",
"\n",
"    try:\n",
"        # requests has no default timeout; bound the call so a stalled connection cannot hang the agent.\n",
"        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)\n",
"        return response.json()\n",
"    except Exception as e:\n",
"        return {\"ticker\": ticker, \"options_chain\": [], \"error\": str(e)}\n",
"\n",
"class GetInsiderTradesInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    limit: int = Field(default=10, description=\"The maximum number of insider transactions to return. Default is 10.\")\n",
"\n",
"@tool(\"get_insider_trades\", args_schema=GetInsiderTradesInput, return_direct=True)\n",
"def get_insider_trades(ticker: str, limit: int = 10) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get insider trading transactions for a ticker.\n",
"    \"\"\"\n",
"    # Returns the API's JSON payload as-is, or {\"ticker\", \"insider_transactions\": [], \"error\"}\n",
"    # if the request or JSON decoding fails. Raises ValueError when the API key is missing.\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    url = 'https://api.financialdatasets.ai/insider-transactions'\n",
"    # Let requests URL-encode the query string instead of concatenating raw values.\n",
"    params = {'ticker': ticker, 'limit': limit}\n",
"\n",
"    try:\n",
"        # requests has no default timeout; bound the call so a stalled connection cannot hang the agent.\n",
"        response = requests.get(url, headers={'X-API-Key': api_key}, params=params, timeout=30)\n",
"        return response.json()\n",
"    except Exception as e:\n",
"        return {\"ticker\": ticker, \"insider_transactions\": [], \"error\": str(e)}\n",
"\n",
"class GetTechnicalIndicatorsInput(BaseModel):\n",
"    \"\"\"Input schema for technical indicators calculation.\"\"\"\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"    indicator: str = Field(..., description=\"The technical indicator to calculate. Valid values are 'rsi', 'macd', 'sma', 'ema', 'bbands'.\")\n",
"    period: Optional[int] = Field(default=14, description=\"The period for the indicator calculation. Default is 14.\")\n",
"    start_date: Optional[str] = Field(default=None, description=\"Start date in YYYY-MM-DD format. Defaults to 90 days before end_date.\")\n",
"    end_date: Optional[str] = Field(default=None, description=\"End date in YYYY-MM-DD format. Defaults to today.\")\n",
"    interval: Optional[str] = Field(default=\"day\", description=\"The time interval for price data.\")\n",
"    interval_multiplier: Optional[int] = Field(default=1, description=\"Multiplier for the time interval.\")\n",
"\n",
"@tool(\"get_technical_indicators\", args_schema=GetTechnicalIndicatorsInput)\n",
"def get_technical_indicators(\n",
"    ticker: str,\n",
"    indicator: str,\n",
"    period: int = 14,\n",
"    interval: str = \"day\",\n",
"    interval_multiplier: int = 1,\n",
"    start_date: Optional[str] = None,\n",
"    end_date: Optional[str] = None,\n",
") -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Calculate technical indicators for a given ticker and time period.\n",
"    Supports RSI, MACD, SMA, EMA, and Bollinger Bands calculations.\n",
"    \"\"\"\n",
"    # Returns {\"ticker\", \"indicator\", \"period\", \"data\": [...]}, where each point has time,\n",
"    # time_milliseconds and value (plus signal_line/histogram for MACD, upper_band/lower_band\n",
"    # for Bollinger Bands), or {\"ticker\", \"indicator\", \"error\"} on failure.\n",
"    name = indicator.lower()\n",
"    supported = (\"rsi\", \"macd\", \"sma\", \"ema\", \"bbands\")\n",
"    if name not in supported:\n",
"        return {\n",
"            \"ticker\": ticker,\n",
"            \"indicator\": indicator,\n",
"            \"error\": f\"Unsupported indicator '{indicator}'. Valid values are: {', '.join(supported)}.\"\n",
"        }\n",
"\n",
"    # The schema marks these Optional, so the model may send explicit nulls.\n",
"    period = 14 if period is None else period\n",
"    interval = interval or \"day\"\n",
"    interval_multiplier = interval_multiplier or 1\n",
"\n",
"    try:\n",
"        # Default the window instead of failing on strptime(None) or sending end_date=None.\n",
"        end_date = end_date or datetime.now().strftime(\"%Y-%m-%d\")\n",
"        if not start_date:\n",
"            start_date = (datetime.strptime(end_date, \"%Y-%m-%d\") - timedelta(days=90)).strftime(\"%Y-%m-%d\")\n",
"\n",
"        # Fetch extra history before start_date so the indicator is warmed up on the first\n",
"        # requested bar. MACD always uses its fixed 26/9 windows regardless of `period`; the bar\n",
"        # count is doubled plus a week to cover weekends and holidays in calendar days.\n",
"        warmup_bars = 26 + 9 if name == \"macd\" else period\n",
"        adjusted_start = (datetime.strptime(start_date, \"%Y-%m-%d\") - timedelta(days=warmup_bars * 2 + 7)).strftime(\"%Y-%m-%d\")\n",
"\n",
"        price_data = get_stock_prices.invoke({\n",
"            \"ticker\": ticker,\n",
"            \"start_date\": adjusted_start,\n",
"            \"end_date\": end_date,\n",
"            \"interval\": interval,\n",
"            \"interval_multiplier\": interval_multiplier\n",
"        })\n",
"\n",
"        if \"error\" in price_data:\n",
"            return price_data\n",
"        if not price_data.get(\"prices\"):\n",
"            return {\"ticker\": ticker, \"indicator\": indicator, \"error\": f\"No price data returned for {ticker}.\"}\n",
"\n",
"        # Strip the ' EDT'/' EST' timezone suffix before parsing the timestamps.\n",
"        df = pd.DataFrame(price_data[\"prices\"])\n",
"        df['time'] = pd.to_datetime(df['time'].apply(lambda x: x.split(' EDT')[0].split(' EST')[0]))\n",
"        # Sort so the date-range slice below is valid even if rows arrive out of order.\n",
"        df = df.set_index('time').sort_index()\n",
"\n",
"        if name == \"rsi\":\n",
"            df['indicator_value'] = ta.momentum.RSIIndicator(df['close'], window=period).rsi()\n",
"        elif name == \"macd\":\n",
"            macd = ta.trend.MACD(df['close'], window_slow=26, window_fast=12, window_sign=9)\n",
"            df['macd_line'] = macd.macd()\n",
"            df['signal_line'] = macd.macd_signal()\n",
"            df['histogram'] = macd.macd_diff()\n",
"            df['indicator_value'] = df['macd_line']\n",
"        elif name == \"sma\":\n",
"            df['indicator_value'] = ta.trend.SMAIndicator(df['close'], window=period).sma_indicator()\n",
"        elif name == \"ema\":\n",
"            df['indicator_value'] = ta.trend.EMAIndicator(df['close'], window=period).ema_indicator()\n",
"        else:  # bbands\n",
"            bb = ta.volatility.BollingerBands(df['close'], window=period, window_dev=2)\n",
"            df['middle_band'] = bb.bollinger_mavg()\n",
"            df['upper_band'] = bb.bollinger_hband()\n",
"            df['lower_band'] = bb.bollinger_lband()\n",
"            df['indicator_value'] = df['middle_band']\n",
"\n",
"        # Extra per-point fields reported alongside `value` for multi-line indicators.\n",
"        extra_columns = {\n",
"            \"macd\": [\"signal_line\", \"histogram\"],\n",
"            \"bbands\": [\"upper_band\", \"lower_band\"],\n",
"        }.get(name, [])\n",
"\n",
"        # Keep only the requested window, then drop warm-up rows the indicator could not\n",
"        # compute instead of back-filling them with later values (which fabricates readings).\n",
"        df = df.loc[start_date:end_date].dropna(subset=['indicator_value', *extra_columns])\n",
"\n",
"        result = {\"ticker\": ticker, \"indicator\": indicator, \"period\": period, \"data\": []}\n",
"        for idx, row in df.iterrows():\n",
"            data_point = {\n",
"                \"time\": idx.strftime(\"%Y-%m-%d %H:%M:%S\"),\n",
"                \"time_milliseconds\": int(idx.timestamp() * 1000),\n",
"                \"value\": float(row['indicator_value'])\n",
"            }\n",
"            data_point.update({column: float(row[column]) for column in extra_columns})\n",
"            result[\"data\"].append(data_point)\n",
"\n",
"        return result\n",
"\n",
"    except Exception as e:\n",
"        return {\n",
"            \"ticker\": ticker,\n",
"            \"indicator\": indicator,\n",
"            \"error\": f\"Error calculating {indicator}: {str(e)}\"\n",
"        }\n",
"\n",
"class GetFinancialMetricsInput(BaseModel):\n",
"    ticker: str = Field(..., description=\"The ticker of the stock.\")\n",
"\n",
"@tool(\"get_financial_metrics\", args_schema=GetFinancialMetricsInput, return_direct=True)\n",
"def get_financial_metrics(ticker: str) -> Union[Dict, str]:\n",
"    \"\"\"\n",
"    Get key financial metrics snapshot for a ticker, including valuation ratios,\n",
"    profitability margins, returns, and growth metrics.\n",
"    \"\"\"\n",
"    # Returns the API's snapshot JSON as-is, or {\"ticker\", \"snapshot\": None, \"error\"} if the\n",
"    # request or JSON decoding fails. Raises ValueError when the API key is missing.\n",
"    api_key = os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    if not api_key:\n",
"        raise ValueError(\"Missing FINANCIAL_DATASETS_API_KEY.\")\n",
"\n",
"    url = \"https://api.financialdatasets.ai/financial-metrics/snapshot\"\n",
"\n",
"    try:\n",
"        # params= URL-encodes the ticker; timeout= stops a stalled connection from hanging the agent.\n",
"        response = requests.get(url, headers={\"X-API-Key\": api_key}, params={\"ticker\": ticker}, timeout=30)\n",
"        return response.json()\n",
"    except Exception as e:\n",
"        return {\"ticker\": ticker, \"snapshot\": None, \"error\": str(e)}\n"
],
"metadata": {
"id": "twLVNqHMb_w9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from langchain_community.tools.tavily_search import TavilySearchResults\n",
"\n",
"# Web/news search tool (capped at 5 results per query), shared by several analysts' tool lists below\n",
"get_news_tool = TavilySearchResults(max_results=5)"
],
"metadata": {
"id": "OQ650p7nM6ad"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Team 1: Traditional Analysis Track\n",
"# NOTE(review): these three lists are not used by the LangGraph workflow below, which only builds the Team 2 analysts.\n",
"fundamental_analyst_tools = [get_income_statements, get_balance_sheets, get_cash_flow_statements, get_financial_metrics]\n",
"technical_analyst_tools = [get_stock_prices, get_current_stock_price, get_technical_indicators]\n",
"sentiment_analyst_tools = [get_options_chain, get_insider_trades, get_news_tool]\n",
"\n",
"# Team 2: Specialized Analysis Track (each list becomes one ReAct analyst in the graph)\n",
"quant_strategist_tools = [get_stock_prices, get_technical_indicators, get_financial_metrics]\n",
"macro_analyst_tools = [get_financial_metrics, get_news_tool, get_technical_indicators]\n",
"event_driven_analyst_tools = [get_news_tool, get_insider_trades, get_financial_metrics, get_current_stock_price]\n",
"derivative_analyst_tools = [get_options_chain, get_technical_indicators, get_current_stock_price]"
],
"metadata": {
"id": "ntBR8eulNR72"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Helper functions"
],
"metadata": {
"id": "F0N8OtbbcG0L"
}
},
{
"cell_type": "code",
"source": [
"# from langchain_core.messages import HumanMessage\n",
"\n",
"# def agent_node(state, agent, name):\n",
"# result = agent.invoke(state)\n",
"# return {\n",
"# \"messages\": [HumanMessage(content=result[\"messages\"][-1].content, name=name)]\n",
"# }"
],
"metadata": {
"id": "I7LhGDuEcCpC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Create LangGraph"
],
"metadata": {
"id": "v5LtjrI-cJ1P"
}
},
{
"cell_type": "code",
"source": [
"from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain_openai import ChatOpenAI\n",
"from pydantic import BaseModel\n",
"from typing import Literal, Sequence, List, Annotated\n",
"from typing_extensions import TypedDict\n",
"import functools\n",
"import operator\n",
"from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage\n",
"from langgraph.graph import END, StateGraph, START\n",
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"# Define team members (these strings are also the analyst node names registered in the graph below)\n",
"members = [\"quant_strategist\", \"macro_analyst\", \"event_driven_analyst\", \"derivative_analyst\"]\n",
"\n",
"# Create routing prompt template\n",
"# The model must answer with a comma-separated list of member names; supervisor_router splits the reply on ','.\n",
"routing_prompt = ChatPromptTemplate.from_messages([\n",
" (\n",
" \"system\",\n",
" \"You are a portfolio manager supervising a hedge fund team with the following analysts:\"\n",
" \"\\n- quant_strategist: Analyzes price patterns, technical indicators, and financial metrics using quantitative methods\"\n",
" \"\\n- macro_analyst: Analyzes broad market trends, economic factors, and their impact on financial metrics\"\n",
" \"\\n- event_driven_analyst: Analyzes special situations, corporate events, insider activity, and real-time price movements\"\n",
" \"\\n- derivative_analyst: Analyzes options flow, volatility patterns, and derivative pricing\"\n",
" \"\\nDetermine which analyst(s) should analyze the request. Respond with ONLY the analyst names\"\n",
" \" separated by commas (e.g., 'quant_strategist,macro_analyst'). Choose analysts based on:\"\n",
" \"\\n- Use quant_strategist for questions about statistical analysis, factor modeling, or technical patterns\"\n",
" \"\\n- Use macro_analyst for questions about economic trends, market-wide impacts, or sector analysis\"\n",
" \"\\n- Use event_driven_analyst for questions about corporate events, news impact, or insider activity\"\n",
" \"\\n- Use derivative_analyst for questions about options activity, volatility, or derivative strategies\"\n",
" ),\n",
" MessagesPlaceholder(variable_name=\"messages\"),\n",
"])\n",
"\n",
"# Create the summary prompt template\n",
"# The trailing human turn asks the model to synthesize every report accumulated in `messages`.\n",
"summary_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a portfolio manager responsible for synthesizing analysis from your team of analysts. \"\n",
" \"Review all the analysts' reports and provide a comprehensive summary including:\\n\"\n",
" \"1. Quantitative and statistical insights (when available)\\n\"\n",
" \"2. Macro and market trend analysis (when available)\\n\"\n",
" \"3. Event-driven factors and news impact (when available)\\n\"\n",
" \"4. Derivatives and volatility analysis (when available)\\n\"\n",
" \"5. Overall investment recommendation\\n\"\n",
" \"Make sure to highlight any discrepancies or conflicting signals between different analyses.\"\n",
" ),\n",
" MessagesPlaceholder(variable_name=\"messages\"),\n",
" (\n",
" \"human\",\n",
" \"Based on all the analyst reports above, provide a comprehensive summary and investment recommendation.\"\n",
" ),\n",
" ]\n",
")\n",
"\n",
"# Initialize LLM\n",
"# Shared by the router, the four analyst agents, and the final summary.\n",
"llm = ChatOpenAI(model=\"gpt-4\")\n",
"\n",
"# Shared graph state. `messages` is reduced with operator.add, so the messages each node\n",
"# returns are appended to the running history instead of replacing it.\n",
"AgentState = TypedDict(\n",
"    \"AgentState\",\n",
"    {\n",
"        \"messages\": Annotated[Sequence[BaseMessage], operator.add],\n",
"        \"selected_analysts\": List[str],  # analyst node names chosen by the supervisor, in run order\n",
"        \"current_analyst_idx\": int,  # index into selected_analysts of the next analyst to run\n",
"    },\n",
")\n",
"\n",
"def supervisor_router(state):\n",
"    \"\"\"Route the query to one or more analysts.\n",
"\n",
"    Asks the LLM for a comma-separated list of analyst names and keeps only names that are\n",
"    real analyst nodes in `members` (first occurrence wins, order preserved).\n",
"\n",
"    Returns a partial state update: the routing note (appended to history by the\n",
"    operator.add reducer on `messages`), the selected analysts, and current_analyst_idx = 0.\n",
"    \"\"\"\n",
"    routing_chain = routing_prompt | llm\n",
"    result = routing_chain.invoke(state)\n",
"\n",
"    # The reply is free text: trim whitespace, quotes and stray periods, and drop anything that\n",
"    # is not a graph node, otherwise get_next_step would return a route with no mapped edge.\n",
"    selected_analysts = []\n",
"    for raw_name in result.content.split(\",\"):\n",
"        name = raw_name.strip().strip(\" .'\\\"`\")\n",
"        if name in members and name not in selected_analysts:\n",
"            selected_analysts.append(name)\n",
"\n",
"    message = SystemMessage(\n",
"        content=f\"Routing query to: {', '.join(selected_analysts)}\",\n",
"        name=\"supervisor\"\n",
"    )\n",
"\n",
"    # Return only the new message: `messages` uses the operator.add reducer, so returning\n",
"    # state[\"messages\"] + [message] would append the entire history a second time.\n",
"    return {\n",
"        \"messages\": [message],\n",
"        \"selected_analysts\": selected_analysts,\n",
"        \"current_analyst_idx\": 0\n",
"    }\n",
"\n",
"def get_next_step(state):\n",
"    \"\"\"Return the node to run next: the analyst at current_analyst_idx, or\n",
"    \"final_summary\" when no analysts were selected or all of them have run.\n",
"    \"\"\"\n",
"    queue = state[\"selected_analysts\"]\n",
"    if queue and state[\"current_analyst_idx\"] < len(queue):\n",
"        return queue[state[\"current_analyst_idx\"]]\n",
"    return \"final_summary\"\n",
"\n",
"def agent_node(state, agent, name):\n",
"    \"\"\"Run one analyst agent on the shared state and record its final answer.\n",
"\n",
"    The agent's last message is re-wrapped as a HumanMessage tagged with `name`, and the\n",
"    analyst cursor is advanced so get_next_step moves on to the next selected analyst.\n",
"    \"\"\"\n",
"    outcome = agent.invoke(state)\n",
"    report_text = outcome[\"messages\"][-1].content\n",
"    update = {\n",
"        \"messages\": [HumanMessage(content=report_text, name=name)],\n",
"        \"selected_analysts\": state[\"selected_analysts\"],\n",
"        \"current_analyst_idx\": state[\"current_analyst_idx\"] + 1,\n",
"    }\n",
"    return update\n",
"\n",
"def final_summary_agent(state):\n",
"    \"\"\"Ask the LLM to synthesize all analyst reports in the history into one recommendation.\"\"\"\n",
"    summary = (summary_prompt | llm).invoke(state)\n",
"    summary_message = HumanMessage(content=summary.content, name=\"portfolio_manager\")\n",
"    return {\n",
"        \"messages\": [summary_message],\n",
"        \"selected_analysts\": state[\"selected_analysts\"],\n",
"        \"current_analyst_idx\": state[\"current_analyst_idx\"],\n",
"    }\n",
"\n",
"# Initialize workflow: supervisor -> selected analysts (one at a time) -> final_summary -> END\n",
"workflow = StateGraph(AgentState)\n",
"\n",
"# One ReAct agent per analyst, each limited to its own tool list\n",
"quant_strategist = create_react_agent(llm, tools=quant_strategist_tools)\n",
"macro_analyst = create_react_agent(llm, tools=macro_analyst_tools)\n",
"event_driven_analyst = create_react_agent(llm, tools=event_driven_analyst_tools)\n",
"derivative_analyst = create_react_agent(llm, tools=derivative_analyst_tools)\n",
"\n",
"# Bind each agent into agent_node, which tags its report with the analyst name and advances the cursor\n",
"quant_strategist_node = functools.partial(agent_node, agent=quant_strategist, name=\"quant_strategist\")\n",
"macro_analyst_node = functools.partial(agent_node, agent=macro_analyst, name=\"macro_analyst\")\n",
"event_driven_analyst_node = functools.partial(agent_node, agent=event_driven_analyst, name=\"event_driven_analyst\")\n",
"derivative_analyst_node = functools.partial(agent_node, agent=derivative_analyst, name=\"derivative_analyst\")\n",
"\n",
"# Register nodes; analyst node names must match the routing targets listed in `members`\n",
"for node_name, node_action in [\n",
"    (\"supervisor\", supervisor_router),\n",
"    (\"quant_strategist\", quant_strategist_node),\n",
"    (\"macro_analyst\", macro_analyst_node),\n",
"    (\"event_driven_analyst\", event_driven_analyst_node),\n",
"    (\"derivative_analyst\", derivative_analyst_node),\n",
"    (\"final_summary\", final_summary_agent),\n",
"]:\n",
"    workflow.add_node(node_name, node_action)\n",
"\n",
"# After the supervisor and after every analyst, get_next_step picks the next analyst or\n",
"# final_summary; each possible return value maps to the node with the same name.\n",
"for source in [\"supervisor\", *members]:\n",
"    workflow.add_conditional_edges(\n",
"        source,\n",
"        get_next_step,\n",
"        {target: target for target in [*members, \"final_summary\"]}\n",
"    )\n",
"\n",
"# Add entry point and final edges\n",
"workflow.add_edge(START, \"supervisor\")\n",
"workflow.add_edge(\"final_summary\", END)\n",
"\n",
"# Compile the graph\n",
"graph = workflow.compile()"
],
"metadata": {
"id": "TT2AggDicQt6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Code to pretty print Agent output"
],
"metadata": {
"id": "V7_AEtWz56-n"
}
},
{
"cell_type": "code",
"source": [
"from typing import Dict, Any\n",
"import json\n",
"import re\n",
"from langchain_core.messages import HumanMessage\n",
"from rich.console import Console\n",
"from rich.panel import Panel\n",
"from rich.text import Text\n",
"from rich.rule import Rule\n",
"\n",
"console = Console()\n",
"\n",
"def format_bold_text(content: str) -> Text:\n",
"    \"\"\"Return `content` as rich Text, rendering each **...** span in bold with the markers removed.\"\"\"\n",
"    # With one capture group, re.split alternates plain text (even indices) and the\n",
"    # captured bold text (odd indices).\n",
"    segments = re.split(r'\\*\\*(.*?)\\*\\*', content)\n",
"    rendered = Text()\n",
"    for position, segment in enumerate(segments):\n",
"        rendered.append(segment, style=\"bold\" if position % 2 else None)\n",
"    return rendered\n",
"\n",
"def format_message_content(content: str) -> Union[str, Text]:\n",
"    \"\"\"Format message content for display.\n",
"\n",
"    JSON text (e.g. a raw tool payload) is pretty-printed with 2-space indentation; text\n",
"    containing **bold** markers becomes rich Text; other strings are returned unchanged.\n",
"    Non-string content (e.g. a list of content parts) is converted with str() so it can\n",
"    still be placed in a rich Panel.\n",
"    \"\"\"\n",
"    try:\n",
"        data = json.loads(content)\n",
"    except (json.JSONDecodeError, TypeError):\n",
"        # Not JSON text (or not text at all): fall through to plain-text handling.\n",
"        # Catching specific errors avoids the bare except hiding unrelated failures.\n",
"        pass\n",
"    else:\n",
"        return json.dumps(data, indent=2)\n",
"\n",
"    if not isinstance(content, str):\n",
"        return str(content)\n",
"    if '**' in content:\n",
"        return format_bold_text(content)\n",
"    return content\n",
"\n",
"def format_agent_message(message: HumanMessage) -> Union[str, Text]:\n",
"    \"\"\"Return the display-ready form of one agent message's content.\"\"\"\n",
"    content = message.content\n",
"    return format_message_content(content)\n",
"\n",
"def get_agent_title(agent: str, message: HumanMessage) -> str:\n",
"    \"\"\"Title-case the message's `name` (e.g. 'macro_analyst' -> 'Macro Analyst'),\n",
"    falling back to the graph node name `agent` when the message has no string name.\n",
"    \"\"\"\n",
"    # Explicit type check instead of a bare except, which would also swallow unrelated errors.\n",
"    name = getattr(message, 'name', None)\n",
"    source = name if isinstance(name, str) else agent\n",
"    return source.replace('_', ' ').title()\n",
"\n",
"def print_step(step: Dict[str, Any]) -> None:\n",
"    \"\"\"Pretty print a single step of the agent execution.\n",
"\n",
"    `step` maps each node name to the state update that node returned. Routing updates\n",
"    (the supervisor's `selected_analysts`, or a legacy `next` key) are shown as a supervision\n",
"    panel; the final_summary update is shown as the investment summary; any other node's\n",
"    latest message is shown as that analyst's report.\n",
"    \"\"\"\n",
"    for agent, data in step.items():\n",
"        if not data:\n",
"            continue  # the node returned no state update\n",
"\n",
"        # supervisor_router returns `selected_analysts` (never `next`), so also key off the\n",
"        # node name; otherwise the routing step is rendered as an analyst report instead.\n",
"        if 'next' in data or (agent == \"supervisor\" and 'selected_analysts' in data):\n",
"            targets = [data['next']] if 'next' in data else (data['selected_analysts'] or [\"final_summary\"])\n",
"            text = Text()\n",
"            text.append(\"Portfolio Manager \", style=\"bold magenta\")\n",
"            text.append(\"assigns next task to \", style=\"white\")\n",
"\n",
"            for position, next_agent in enumerate(targets):\n",
"                if position:\n",
"                    text.append(\", \", style=\"white\")\n",
"                if next_agent == \"final_summary\":\n",
"                    text.append(\"FINAL SUMMARY\", style=\"bold yellow\")\n",
"                elif next_agent == \"END\":\n",
"                    text.append(\"END\", style=\"bold red\")\n",
"                else:\n",
"                    text.append(f\"{next_agent}\", style=\"bold green\")\n",
"\n",
"            console.print(Panel(\n",
"                text,\n",
"                title=\"[bold blue]Supervision Step\",\n",
"                border_style=\"blue\"\n",
"            ))\n",
"            continue\n",
"\n",
"        # Handle agent responses and final summary\n",
"        if data.get('messages'):\n",
"            # A node's own output is the last message it returned.\n",
"            message = data['messages'][-1]\n",
"            formatted_content = format_agent_message(message)\n",
"\n",
"            if agent == \"final_summary\":\n",
"                # Final summary formatting\n",
"                console.print(Rule(style=\"yellow\", title=\"Portfolio Analysis\"))\n",
"                console.print(Panel(\n",
"                    formatted_content,\n",
"                    title=\"[bold yellow]Investment Summary and Recommendation\",\n",
"                    border_style=\"yellow\",\n",
"                    padding=(1, 2)\n",
"                ))\n",
"                console.print(Rule(style=\"yellow\"))\n",
"            else:\n",
"                # Regular analyst reports\n",
"                title = get_agent_title(agent, message)\n",
"                console.print(Panel(\n",
"                    formatted_content,\n",
"                    title=f\"[bold blue]{title} Report\",\n",
"                    border_style=\"green\"\n",
"                ))\n",
"\n",
"def stream_agent_execution(graph, input_data: Dict, config: Dict) -> None:\n",
"    \"\"\"Run `graph` on `input_data` with `config` and pretty-print each streamed update as it arrives.\"\"\"\n",
"    console.print(\"\\n[bold blue]Starting Agent Execution...[/bold blue]\\n\")\n",
"\n",
"    for update in graph.stream(input_data, config):\n",
"        if \"__end__\" in update:\n",
"            continue\n",
"        print_step(update)\n",
"        console.print(\"\\n\")\n",
"\n",
"    console.print(\"[bold blue]Analysis Complete[/bold blue]\\n\")"
],
"metadata": {
"id": "t2E2mnnJ5LaN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Run the Hedge Fund team"
],
"metadata": {
"id": "Y1_IZnAUTAHw"
}
},
{
"cell_type": "code",
"source": [
"question = \"What is AAPL's current price and latest revenue?\"\n",
"input_data = {\"messages\": [HumanMessage(content=question)]}\n",
"# recursion_limit bounds the number of graph steps; one run visits the supervisor, up to 4 analysts and final_summary.\n",
"config = {\"recursion_limit\": 10}\n",
"stream_agent_execution(graph, input_data, config)"
],
"metadata": {
"id": "gLUCOhL85Lip"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"question = \"What is AAPL's latest news?\"\n",
"input_data = {\"messages\": [HumanMessage(content=question)]}\n",
"# recursion_limit bounds the number of graph steps; one run visits the supervisor, up to 4 analysts and final_summary.\n",
"config = {\"recursion_limit\": 10}\n",
"stream_agent_execution(graph, input_data, config)"
],
"metadata": {
"id": "pYyfbkCLNmGD"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment