Skip to content

Instantly share code, notes, and snippets.

@virattt
Created November 21, 2024 23:59
Show Gist options
  • Save virattt/b624560b05f50b69f1042a7a4f4adceb to your computer and use it in GitHub Desktop.
Save virattt/b624560b05f50b69f1042a7a4f4adceb to your computer and use it in GitHub Desktop.
sma-agent-backtesting.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyO5xx25x39Ls8ItrGZh6PYz",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/b624560b05f50b69f1042a7a4f4adceb/sma-agent-backtesting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jDTXcDmUj0bV"
},
"outputs": [],
"source": [
"!pip install -U --quiet langgraph langchain_openai"
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"\n",
"def _set_if_undefined(var: str) -> None:\n",
"    \"\"\"Prompt for environment variable ``var`` unless it already has a non-empty value.\n",
"\n",
"    The value is read with ``getpass`` so it is not echoed into the notebook output.\n",
"    \"\"\"\n",
"    if os.environ.get(var):\n",
"        return\n",
"    os.environ[var] = getpass.getpass(f\"Please provide your {var}\")\n",
"\n",
"\n",
"_set_if_undefined(\"FINANCIAL_DATASETS_API_KEY\")  # For getting financial data. Get from https://financialdatasets.ai\n",
"_set_if_undefined(\"OPENAI_API_KEY\")  # For calling the OpenAI chat model. Get from https://platform.openai.com/api-keys"
],
"metadata": {
"id": "azDJXqYSl2sF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import requests\n",
"import os\n",
"from datetime import timedelta\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Import your agent's dependencies\n",
"from langchain_openai.chat_models import ChatOpenAI\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"from langgraph.graph import StateGraph, MessagesState"
],
"metadata": {
"id": "lP3IylZfm744"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 1. Create the Agent"
],
"metadata": {
"id": "-adeYrdHpHvx"
}
},
{
"cell_type": "code",
"source": [
"# Initialize the OpenAI model\n",
"gpt_4o_model = ChatOpenAI(model=\"gpt-4o\", temperature=0)\n",
"\n",
"# System prompt: SMA crossover rules plus portfolio-management constraints\n",
"system_prompt = \"\"\"\n",
"You are a financial trading agent using a Simple Moving Average (SMA) crossover strategy with portfolio management capabilities.\n",
"\n",
"Your responsibilities:\n",
"1. Analyze technical indicators (5-day and 20-day SMAs)\n",
"2. Manage portfolio allocation based on:\n",
" - Available cash\n",
" - Current stock position\n",
" - Current market price\n",
"3. Generate specific trading orders with:\n",
" - Action: 'buy', 'sell', or 'hold'\n",
" - Quantity: number of shares to trade (must be affordable for buys or available for sells)\n",
"\n",
"Base your decisions on:\n",
"- SMA crossover signals (5-day vs 20-day)\n",
"- Current portfolio state\n",
"- Risk management (don't use more than 50% of available cash in a single trade)\n",
"- Trading constraints:\n",
" * For buy orders: ensure quantity * current_price <= available cash\n",
" * For sell orders: ensure quantity <= current stock position\n",
"\n",
"Your response should be in the format:\n",
"{\n",
" \"action\": \"<buy|sell|hold>\",\n",
" \"quantity\": <number_of_shares>\n",
"}\n",
"Only output this JSON object, without any additional text.\n",
"\"\"\"\n",
"\n",
"\n",
"def call_agent(state: MessagesState):\n",
"    \"\"\"Graph node that asks the model for a trading decision.\n",
"\n",
"    Makes sure the system prompt is the first message. When the last message\n",
"    carries ``ticker``, ``start_date`` and ``end_date`` in ``additional_kwargs``,\n",
"    its content is replaced with a price/SMA/portfolio snapshot built from\n",
"    ``get_price_data`` and ``calculate_trading_signals``.\n",
"\n",
"    Returns:\n",
"        dict: ``{\"messages\": [reply]}`` holding the model's response.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``state[\"messages\"]`` is empty.\n",
"    \"\"\"\n",
"    messages = state[\"messages\"]\n",
"    if not messages:\n",
"        # Fail with a clear message instead of an IndexError on messages[-1] below.\n",
"        raise ValueError(\"call_agent requires at least one message in state['messages']\")\n",
"\n",
"    if messages[0].content != system_prompt:\n",
"        messages.insert(0, SystemMessage(content=system_prompt))\n",
"\n",
"    request = messages[-1]\n",
"    # LangChain messages always define ``additional_kwargs`` (empty by default), so\n",
"    # check for the keys that are actually needed rather than for the attribute.\n",
"    params = getattr(request, \"additional_kwargs\", None) or {}\n",
"    if all(key in params for key in (\"ticker\", \"start_date\", \"end_date\")):\n",
"        historical_data = get_price_data(\n",
"            params[\"ticker\"], params[\"start_date\"], params[\"end_date\"]\n",
"        )\n",
"        signals = calculate_trading_signals(historical_data)\n",
"        portfolio = params.get(\"portfolio\", {\n",
"            \"cash\": params.get(\"initial_capital\", 100000),\n",
"            \"stock\": 0\n",
"        })\n",
"\n",
"        # Replace the placeholder request text with the concrete market snapshot.\n",
"        request.content = f\"\"\"\n",
"    Current price: {signals['current_price']:.2f}\n",
"    5-day SMA: {signals['sma_5_curr']:.2f} (previous: {signals['sma_5_prev']:.2f})\n",
"    20-day SMA: {signals['sma_20_curr']:.2f} (previous: {signals['sma_20_prev']:.2f})\n",
"\n",
"    Portfolio:\n",
"    Cash: ${portfolio['cash']:.2f}\n",
"    Shares: {portfolio['stock']}\n",
"\n",
"    Based on the SMA crossover strategy and current portfolio, what is your trading decision?\n",
"    \"\"\"\n",
"\n",
"    return {\"messages\": [gpt_4o_model.invoke(messages)]}\n",
"\n",
"\n",
"# Define the agent graph\n",
"workflow = StateGraph(MessagesState)\n",
"workflow.add_node(\"agent\", call_agent)\n",
"workflow.set_entry_point(\"agent\")\n",
"app = workflow.compile()\n",
"\n",
"\n",
"def run_agent(ticker: str, start_date: str, end_date: str, portfolio: dict):\n",
"    \"\"\"Run the graph once and return the content of the model's reply.\n",
"\n",
"    The ticker, date window and portfolio travel in the request message's\n",
"    ``additional_kwargs``, where ``call_agent`` picks them up.\n",
"    \"\"\"\n",
"    request = HumanMessage(\n",
"        content=\"Make a trading decision based on the provided data.\",\n",
"        additional_kwargs={\n",
"            \"ticker\": ticker,\n",
"            \"start_date\": start_date,\n",
"            \"end_date\": end_date,\n",
"            \"portfolio\": portfolio\n",
"        },\n",
"    )\n",
"    final_state = app.invoke(\n",
"        {\"messages\": [request]},\n",
"        config={\"configurable\": {\"thread_id\": 42}},\n",
"    )\n",
"    return final_state[\"messages\"][-1].content"
],
"metadata": {
"id": "tuMK7Jrxj3eY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Smoke test: request a single trading decision.\n",
"# Run this only after the cells that define get_price_data (section 2) and\n",
"# calculate_trading_signals (section 3); call_agent needs both at call time.\n",
"run_agent(\n",
"    ticker=\"AAPL\",\n",
"    start_date=\"2023-12-01\",\n",
"    end_date=\"2024-01-31\",\n",
"    portfolio={\"cash\": 100000, \"stock\": 0},\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "UnlmOdL8j7QD",
"outputId": "4197be79-43fb-440f-f1dd-f2fdd5d9b282"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 2. Get Price Data"
],
"metadata": {
"id": "DsvZo7gjpMGu"
}
},
{
"cell_type": "code",
"source": [
"def get_price_data(ticker: str, start_date: str, end_date: str) -> pd.DataFrame:\n",
"    \"\"\"Fetch daily prices for ``ticker`` from the financialdatasets.ai prices endpoint.\n",
"\n",
"    Args:\n",
"        ticker: Symbol to query.\n",
"        start_date: First day of the window, passed through as ``start_date``.\n",
"        end_date: Last day of the window, passed through as ``end_date``.\n",
"\n",
"    Returns:\n",
"        DataFrame indexed by ``Date`` (parsed from the ``time`` field), sorted\n",
"        ascending, with ``open``/``close``/``high``/``low``/``volume`` coerced to\n",
"        numeric (unparseable values become NaN).\n",
"\n",
"    Raises:\n",
"        Exception: if the API answers with a non-200 status code.\n",
"        ValueError: if the response contains no prices.\n",
"        requests.exceptions.RequestException: on connection errors or timeout.\n",
"    \"\"\"\n",
"    # Add your API key to the headers\n",
"    headers = {\n",
"        \"X-API-KEY\": os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
"    }\n",
"\n",
"    # Let requests build and percent-encode the query string\n",
"    params = {\n",
"        \"ticker\": ticker,\n",
"        \"interval\": \"day\",\n",
"        \"interval_multiplier\": 1,\n",
"        \"start_date\": start_date,\n",
"        \"end_date\": end_date,\n",
"    }\n",
"\n",
"    # The timeout keeps one stalled connection from hanging a whole backtest\n",
"    response = requests.get(\n",
"        \"https://api.financialdatasets.ai/prices/\",\n",
"        headers=headers,\n",
"        params=params,\n",
"        timeout=30,\n",
"    )\n",
"\n",
"    # Check for successful response\n",
"    if response.status_code != 200:\n",
"        raise Exception(f\"Error fetching data: {response.status_code} - {response.text}\")\n",
"\n",
"    # Parse prices from the response\n",
"    prices = response.json().get('prices')\n",
"    if not prices:\n",
"        raise ValueError(\"No price data returned\")\n",
"\n",
"    # Convert prices to DataFrame\n",
"    df = pd.DataFrame(prices)\n",
"\n",
"    # Convert 'time' to datetime and set as index\n",
"    df['Date'] = pd.to_datetime(df['time'])\n",
"    df.set_index('Date', inplace=True)\n",
"\n",
"    # Ensure numeric data types\n",
"    for col in ('open', 'close', 'high', 'low', 'volume'):\n",
"        df[col] = pd.to_numeric(df[col], errors='coerce')\n",
"\n",
"    # Sort by date so the last row is the most recent bar\n",
"    df.sort_index(inplace=True)\n",
"\n",
"    return df\n"
],
"metadata": {
"id": "tCYkY3EDmCLM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 3. Define Trading Strategy"
],
"metadata": {
"id": "4R8X9AgfKhoz"
}
},
{
"cell_type": "code",
"source": [
"# Define a function to calculate trading signals\n",
"def calculate_trading_signals(historical_data: pd.DataFrame) -> dict:\n",
"    \"\"\"Calculate trading signals based on SMA crossover strategy.\n",
"\n",
"    Args:\n",
"        historical_data: Price history with a ``close`` column, oldest row first.\n",
"\n",
"    Returns:\n",
"        dict with ``current_price`` (last close) and the current and previous\n",
"        values of the 5-row and 20-row simple moving averages. An SMA value is\n",
"        NaN when fewer rows than its window are available at that point.\n",
"\n",
"    Raises:\n",
"        ValueError: if fewer than two rows are supplied, because a crossover\n",
"            check needs both a previous and a current SMA value.\n",
"    \"\"\"\n",
"    closes = historical_data[\"close\"]\n",
"    if len(closes) < 2:\n",
"        raise ValueError(\n",
"            f\"Need at least 2 closing prices to compare SMAs, got {len(closes)}\"\n",
"        )\n",
"\n",
"    # Calculate SMAs\n",
"    sma_5 = closes.rolling(window=5).mean()\n",
"    sma_20 = closes.rolling(window=20).mean()\n",
"\n",
"    # Get the last two points of each SMA to check for crossover\n",
"    sma_5_prev, sma_5_curr = sma_5.iloc[-2:]\n",
"    sma_20_prev, sma_20_curr = sma_20.iloc[-2:]\n",
"\n",
"    return {\n",
"        \"current_price\": closes.iloc[-1],\n",
"        \"sma_5_curr\": sma_5_curr,\n",
"        \"sma_5_prev\": sma_5_prev,\n",
"        \"sma_20_curr\": sma_20_curr,\n",
"        \"sma_20_prev\": sma_20_prev,\n",
"    }"
],
"metadata": {
"id": "25BbhA_FKj3k"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 4. Create a backtester"
],
"metadata": {
"id": "eib4alsrpQ3h"
}
},
{
"cell_type": "code",
"source": [
"class Backtester:\n",
"    \"\"\"Day-by-day backtest of a trading agent on a single ticker.\n",
"\n",
"    ``agent`` is called with ``ticker``, ``start_date``, ``end_date`` and\n",
"    ``portfolio`` keyword arguments and must return the decision text (the JSON\n",
"    format requested in the system prompt). Trades are filled at the latest\n",
"    close returned by ``get_price_data``.\n",
"    \"\"\"\n",
"\n",
"    VALID_ACTIONS = (\"buy\", \"sell\", \"hold\")\n",
"\n",
"    def __init__(self, agent, ticker, start_date, end_date, initial_capital, lookback_days=45):\n",
"        self.agent = agent\n",
"        self.ticker = ticker\n",
"        self.start_date = start_date\n",
"        self.end_date = end_date\n",
"        self.initial_capital = initial_capital\n",
"        # The previous 20-day SMA needs 21 trading days of history. A 30-calendar-day\n",
"        # window can hold only 20 of them around holidays, which leaves that value\n",
"        # NaN, so the default window is 45 calendar days.\n",
"        self.lookback_days = lookback_days\n",
"        self.portfolio = {\"cash\": initial_capital, \"stock\": 0}\n",
"        self.portfolio_values = []\n",
"\n",
"    def parse_action(self, agent_output):\n",
"        \"\"\"Turn the agent's reply into ``(action, quantity)``.\n",
"\n",
"        Accepts a bare JSON object or one wrapped in extra text such as a Markdown\n",
"        code fence. Anything that cannot be parsed, an unknown action, or a\n",
"        negative quantity falls back to ``(\"hold\", 0)``. Quantities are truncated\n",
"        to whole shares.\n",
"        \"\"\"\n",
"        import json\n",
"\n",
"        text = str(agent_output).strip()\n",
"        # Keep only the outermost {...} so fenced or annotated replies still parse.\n",
"        start, end = text.find(\"{\"), text.rfind(\"}\")\n",
"        if start != -1 and end > start:\n",
"            text = text[start:end + 1]\n",
"\n",
"        try:\n",
"            decision = json.loads(text)\n",
"            action = str(decision[\"action\"]).strip().lower()\n",
"            quantity = int(float(decision[\"quantity\"]))\n",
"        except (ValueError, KeyError, TypeError, OverflowError):\n",
"            # json.JSONDecodeError is a ValueError; the others cover wrong shapes/types.\n",
"            return \"hold\", 0\n",
"\n",
"        if action not in self.VALID_ACTIONS or quantity < 0:\n",
"            return \"hold\", 0\n",
"        return action, quantity\n",
"\n",
"    def execute_trade(self, action, quantity, current_price):\n",
"        \"\"\"Validate and execute trades based on portfolio constraints.\n",
"\n",
"        Buys are capped at what the available cash can afford and sells at the\n",
"        shares currently held. Returns the number of shares actually traded\n",
"        (0 when nothing was executed).\n",
"        \"\"\"\n",
"        if action == \"buy\" and quantity > 0:\n",
"            cost = quantity * current_price\n",
"            if cost <= self.portfolio[\"cash\"]:\n",
"                self.portfolio[\"stock\"] += quantity\n",
"                self.portfolio[\"cash\"] -= cost\n",
"                return quantity\n",
"\n",
"            # Calculate maximum affordable quantity\n",
"            max_quantity = self.portfolio[\"cash\"] // current_price\n",
"            if max_quantity > 0:\n",
"                # Floor division on floats yields a float; keep share counts integral.\n",
"                max_quantity = int(max_quantity)\n",
"                self.portfolio[\"stock\"] += max_quantity\n",
"                self.portfolio[\"cash\"] -= max_quantity * current_price\n",
"                return max_quantity\n",
"            return 0\n",
"\n",
"        if action == \"sell\" and quantity > 0:\n",
"            quantity = min(quantity, self.portfolio[\"stock\"])\n",
"            if quantity > 0:\n",
"                self.portfolio[\"cash\"] += quantity * current_price\n",
"                self.portfolio[\"stock\"] -= quantity\n",
"                return quantity\n",
"            return 0\n",
"\n",
"        return 0\n",
"\n",
"    def run_backtest(self):\n",
"        \"\"\"Step through every business day in the range, query the agent and trade.\n",
"\n",
"        Prints one row per day and appends the day's total value to\n",
"        ``self.portfolio_values``.\n",
"        \"\"\"\n",
"        # NOTE(review): freq=\"B\" yields every weekday, market holidays included; on\n",
"        # such days the latest available close is reused. Confirm whether those days\n",
"        # should be skipped.\n",
"        dates = pd.date_range(self.start_date, self.end_date, freq=\"B\")\n",
"\n",
"        print(\"\\nStarting backtest...\")\n",
"        print(f\"{'Date':<12} {'Action':<6} {'Quantity':>8} {'Price':>8} {'Cash':>12} {'Stock':>8} {'Total Value':>12}\")\n",
"        print(\"-\" * 70)\n",
"\n",
"        for current_date in dates:\n",
"            lookback_start = (current_date - timedelta(days=self.lookback_days)).strftime(\"%Y-%m-%d\")\n",
"            current_date_str = current_date.strftime(\"%Y-%m-%d\")\n",
"\n",
"            agent_output = self.agent(\n",
"                ticker=self.ticker,\n",
"                start_date=lookback_start,\n",
"                end_date=current_date_str,\n",
"                portfolio=self.portfolio\n",
"            )\n",
"\n",
"            action, quantity = self.parse_action(agent_output)\n",
"            df = get_price_data(self.ticker, lookback_start, current_date_str)\n",
"            current_price = df.iloc[-1]['close']\n",
"\n",
"            # Execute the trade with validation\n",
"            executed_quantity = self.execute_trade(action, quantity, current_price)\n",
"\n",
"            # Update total portfolio value\n",
"            total_value = self.portfolio[\"cash\"] + self.portfolio[\"stock\"] * current_price\n",
"            self.portfolio[\"portfolio_value\"] = total_value\n",
"\n",
"            # Log the current state with executed quantity\n",
"            print(\n",
"                f\"{current_date_str:<12} {action:<6} {executed_quantity:>8} {current_price:>8.2f} \"\n",
"                f\"{self.portfolio['cash']:>12.2f} {self.portfolio['stock']:>8} {total_value:>12.2f}\"\n",
"            )\n",
"\n",
"            # Record the portfolio value\n",
"            self.portfolio_values.append(\n",
"                {\"Date\": current_date, \"Portfolio Value\": total_value}\n",
"            )\n",
"\n",
"    def analyze_performance(self):\n",
"        \"\"\"Print total return, Sharpe ratio and maximum drawdown, and plot the equity curve.\n",
"\n",
"        Returns:\n",
"            DataFrame indexed by ``Date`` with ``Portfolio Value`` and ``Daily Return``.\n",
"\n",
"        Raises:\n",
"            ValueError: if no portfolio history has been recorded yet.\n",
"        \"\"\"\n",
"        if not self.portfolio_values:\n",
"            raise ValueError(\"No portfolio history recorded; call run_backtest() first\")\n",
"\n",
"        # Convert portfolio values to DataFrame\n",
"        performance_df = pd.DataFrame(self.portfolio_values).set_index(\"Date\")\n",
"\n",
"        # Calculate total return\n",
"        total_return = (\n",
"            self.portfolio[\"portfolio_value\"] - self.initial_capital\n",
"        ) / self.initial_capital\n",
"        print(f\"Total Return: {total_return * 100:.2f}%\")\n",
"\n",
"        # Plot the portfolio value over time\n",
"        performance_df[\"Portfolio Value\"].plot(\n",
"            title=\"Portfolio Value Over Time\", figsize=(12, 6)\n",
"        )\n",
"        plt.ylabel(\"Portfolio Value ($)\")\n",
"        plt.xlabel(\"Date\")\n",
"        plt.show()\n",
"\n",
"        # Compute daily returns\n",
"        performance_df[\"Daily Return\"] = performance_df[\"Portfolio Value\"].pct_change()\n",
"\n",
"        # Calculate Sharpe Ratio (assuming 252 trading days in a year)\n",
"        mean_daily_return = performance_df[\"Daily Return\"].mean()\n",
"        std_daily_return = performance_df[\"Daily Return\"].std()\n",
"        if pd.isna(std_daily_return) or std_daily_return == 0:\n",
"            # Fewer than two returns, or a flat equity curve: the ratio is undefined.\n",
"            sharpe_ratio = float(\"nan\")\n",
"        else:\n",
"            sharpe_ratio = (mean_daily_return / std_daily_return) * (252**0.5)\n",
"        print(f\"Sharpe Ratio: {sharpe_ratio:.2f}\")\n",
"\n",
"        # Calculate Maximum Drawdown\n",
"        rolling_max = performance_df[\"Portfolio Value\"].cummax()\n",
"        drawdown = performance_df[\"Portfolio Value\"] / rolling_max - 1\n",
"        max_drawdown = drawdown.min()\n",
"        print(f\"Maximum Drawdown: {max_drawdown * 100:.2f}%\")\n",
"\n",
"        return performance_df"
],
"metadata": {
"id": "sycJCgYunBfq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 5. Run the Backtest"
],
"metadata": {
"id": "8gbjCJT-pTa0"
}
},
{
"cell_type": "code",
"source": [
"# Define parameters\n",
"ticker = \"AAPL\" # Example ticker symbol\n",
"start_date = \"2024-01-01\" # Adjust as needed\n",
"end_date = \"2024-10-31\" # Adjust as needed\n",
"initial_capital = 100000 # $100,000\n",
"\n",
"# Create an instance of Backtester\n",
"backtester = Backtester(\n",
" agent=run_agent,\n",
" ticker=ticker,\n",
" start_date=start_date,\n",
" end_date=end_date,\n",
" initial_capital=initial_capital,\n",
")\n",
"\n",
"# Run the backtesting process\n",
"backtester.run_backtest()\n",
"performance_df = backtester.analyze_performance()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 959
},
"id": "ro6_juA_nHl-",
"outputId": "02f848e0-baa6-4f5d-b0d7-30b27292408e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Starting backtest...\n",
"Date Action Price Shares Cash Total Value\n",
"------------------------------------------------------------\n",
"2024-01-09 buy 185.14 540 24.40 100000.00\n",
"2024-01-10 hold 186.19 540 24.40 100567.00\n",
"2024-01-11 buy 185.59 540 24.40 100243.00\n",
"2024-01-12 hold 185.92 540 24.40 100421.20\n",
"2024-01-16 hold 183.63 540 24.40 99184.60\n",
"2024-01-17 sell 182.68 0 98671.60 98671.60\n",
"2024-01-18 buy 188.63 523 18.11 98671.60\n",
"2024-01-19 buy 191.56 523 18.11 100203.99\n",
"2024-01-22 buy 193.89 523 18.11 101422.58\n",
"2024-01-23 buy 195.18 523 18.11 102097.25\n",
"2024-01-24 hold 194.50 523 18.11 101741.61\n",
"2024-01-25 hold 194.17 523 18.11 101569.02\n",
"2024-01-26 hold 192.42 523 18.11 100653.77\n",
"2024-01-29 sell 191.73 0 100292.90 100292.90\n",
"2024-01-30 buy 188.04 533 67.58 100292.90\n",
"2024-01-31 buy 184.40 533 67.58 98352.78\n",
"Total Return: -1.65%\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Sharpe Ratio: -1.89\n",
"Maximum Drawdown: -3.67%\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "3qmllXP2kMH4"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment