Skip to content

Instantly share code, notes, and snippets.

@namuan
Created October 25, 2024 13:26
Show Gist options
  • Save namuan/65e6ba8e3793095e34c819cae2accfad to your computer and use it in GitHub Desktop.
Save namuan/65e6ba8e3793095e34c819cae2accfad to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Script for plotting animated moving averages of stock prices.
Examples:
python3 spy_30y.py --show
"""
import logging
import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.animation import FuncAnimation, PillowWriter
from pandas import DataFrame
from common.market import get_cached_data
# Constants
# Rolling-window sizes; each one yields an "MA<n>" column that is computed and plotted.
MOVING_AVERAGES = [5, 20, 50, 100, 150, 200]
# Defaults for the --symbol and --period command line options.
DEFAULT_SYMBOL = "SPY"
DEFAULT_PERIOD = 1
# How far back the date range starts when neither --start_date nor --year is given.
DEFAULT_LOOKBACK_YEARS = 30
# Format used both for parsing date arguments and for log timestamps.
DATE_FORMAT = "%Y-%m-%d"
# Scatter colors for rising and falling moving-average points.
UP_COLOR = "#20A428" # Dark green
DOWN_COLOR = "#EB2119" # Dark red
# Figure size handed to plt.subplots(figsize=...).
FIGURE_SIZE = (15, 8)
# NOTE(review): DPI, ANIMATION_INTERVAL and ANIMATION_TRAIL_LENGTH are not
# referenced anywhere in the visible source — confirm whether they are still needed.
DPI = 300
ANIMATION_INTERVAL = 50 # milliseconds between frames
ANIMATION_TRAIL_LENGTH = 50 # number of points to show in trail
def setup_logging(verbosity: int) -> None:
    """
    Set up root logging based on verbosity level.

    Every handler already attached to the root logger is detached first,
    because ``logging.basicConfig`` does nothing when the root logger
    still has handlers.

    Args:
        verbosity: Integer indicating verbosity level (0=WARNING, 1=INFO, 2+=DEBUG)
    """
    logging_level = logging.WARNING
    if verbosity == 1:
        logging_level = logging.INFO
    elif verbosity >= 2:
        logging_level = logging.DEBUG

    root = logging.getLogger()
    # Iterate over a copy: removeHandler() mutates root.handlers, and removing
    # entries from the live list while looping over it skips every other handler.
    for handler in list(root.handlers):
        root.removeHandler(handler)

    logging.basicConfig(
        stream=sys.stdout,
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
        datefmt=DATE_FORMAT,
        level=logging_level,
    )
    logging.debug(
        "Logging configured with level: %s", logging.getLevelName(logging_level)
    )
def validate_dates(start_date: datetime.date, end_date: datetime.date) -> bool:
    """
    Check that a date range is ordered and does not reach into the future.

    Args:
        start_date: First day of the analysis range
        end_date: Last day of the analysis range

    Returns:
        bool: True when start_date is not after end_date and end_date is not
        later than today; False otherwise (the reason is logged as an error)
    """
    problem = None
    if start_date > end_date:
        problem = "Start date must be before end date"
    elif end_date > datetime.now().date():
        problem = "End date cannot be in the future"

    if problem is None:
        return True
    logging.error(problem)
    return False
@dataclass
class ArgOptions:
    """Container for the parsed command line options."""

    # Number of -v flags given (drives the logging level).
    verbose: int
    # Ticker symbol to analyze.
    symbol: str
    # Optional starting year for the analysis.
    year: Optional[int]
    # Optional range boundaries as text in DATE_FORMAT.
    start_date: Optional[str]
    end_date: Optional[str]
    # Period for price change calculation; must be positive.
    period: int
    # Optional file path the animation is saved to.
    output: Optional[str]
    # Whether to open the animation window.
    show: bool

    def validate(self) -> bool:
        """
        Check the options for consistency, logging the first problem found.

        Returns:
            bool: True if every check passes, False otherwise
        """
        if self.period <= 0:
            logging.error("Period must be a positive integer")
            return False

        # The date strings are only cross-checked when both were supplied.
        if self.start_date and self.end_date:
            try:
                start, end = (
                    datetime.strptime(text, DATE_FORMAT).date()
                    for text in (self.start_date, self.end_date)
                )
            except ValueError as e:
                logging.error(f"Invalid date format: {str(e)}")
                return False
            if start > end:
                logging.error("Start date must be before end date")
                return False

        if self.output or self.show:
            return True
        logging.error("Either --output or --show (or both) must be specified")
        return False
def parse_args() -> ArgOptions:
    """
    Read the command line, build an ArgOptions from it and validate it.

    The process exits with status 1 when validation fails.

    Returns:
        ArgOptions: Validated command line arguments
    """
    cli = ArgumentParser(
        description=__doc__, formatter_class=RawDescriptionHelpFormatter
    )
    cli.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        dest="verbose",
        help="Increase verbosity of logging output",
    )
    cli.add_argument(
        "-s",
        "--symbol",
        type=str,
        default=DEFAULT_SYMBOL,
        help=f"Stock symbol to analyze (default: {DEFAULT_SYMBOL})",
    )
    cli.add_argument("-y", "--year", type=int, help="Starting year for analysis")
    cli.add_argument(
        "-sd",
        "--start_date",
        type=str,
        help="Start date for analysis (format: YYYY-MM-DD)",
    )
    cli.add_argument(
        "-ed",
        "--end_date",
        type=str,
        help="End date for analysis (format: YYYY-MM-DD, default: today)",
    )
    cli.add_argument(
        "-p",
        "--period",
        type=int,
        default=DEFAULT_PERIOD,
        help=f"Period for price change calculation (in days, default: {DEFAULT_PERIOD})",
    )
    cli.add_argument(
        "-o",
        "--output",
        type=str,
        help="Output file path to save the animation (e.g., /path/to/plot.gif)",
    )
    cli.add_argument(
        "--show",
        action="store_true",
        help="Display the animation window (default: False)",
    )

    # The parser destinations match the ArgOptions field names one-to-one.
    options = ArgOptions(**vars(cli.parse_args()))
    if options.validate():
        return options
    sys.exit(1)
def fetch_stock_data(
    symbol: str, start_date: datetime.date, end_date: datetime.date
) -> DataFrame:
    """
    Load price data for a symbol over a date range via get_cached_data.

    Args:
        symbol: Stock symbol
        start_date: Start date for data fetch
        end_date: End date for data fetch

    Returns:
        DataFrame: Whatever get_cached_data returned, or an empty DataFrame
        when the lookup raised (the error is logged)
    """
    logging.debug(f"Fetching data for {symbol} from {start_date} to {end_date}")
    try:
        prices = get_cached_data(symbol, start=start_date, end=end_date)
        if prices.empty:
            logging.warning(f"Data for symbol {symbol} is empty")
    except Exception as e:
        # Best-effort boundary: callers treat an empty frame as "no data".
        logging.error(f"Error fetching data for {symbol}: {str(e)}")
        return pd.DataFrame()
    else:
        return prices
def handle_dates(args: ArgOptions) -> Tuple[datetime.date, datetime.date]:
    """
    Resolve the analysis date range from the command line options.

    The end date is --end_date or today. The start date is taken, in order of
    preference, from --start_date, from January 1st of --year, or from
    DEFAULT_LOOKBACK_YEARS * 365 days before the end date. The process exits
    with status 1 when a value cannot be parsed or the range is rejected by
    validate_dates.

    Args:
        args: Command line arguments

    Returns:
        Tuple[datetime.date, datetime.date]: Start and end dates
    """
    try:
        if args.end_date:
            end_date = datetime.strptime(args.end_date, DATE_FORMAT).date()
        else:
            end_date = datetime.now().date()

        if args.start_date:
            start_date = datetime.strptime(args.start_date, DATE_FORMAT).date()
        elif args.year:
            start_date = datetime(args.year, 1, 1).date()
        else:
            lookback = timedelta(days=DEFAULT_LOOKBACK_YEARS * 365)
            start_date = end_date - lookback
    except ValueError as e:
        logging.error(f"Invalid date format: {str(e)}")
        sys.exit(1)

    if not validate_dates(start_date, end_date):
        sys.exit(1)
    logging.debug(f"Date range: {start_date} to {end_date}")
    return start_date, end_date
import matplotlib.dates
def create_animated_plot(df: DataFrame, symbol: str, args: ArgOptions) -> None:
    """
    Create an animated plot of moving averages with historical event annotations.

    Each moving average is drawn as two scatter series: points where the
    average rose versus the previous row (UP_COLOR) and points where it fell
    (DOWN_COLOR). The animation reveals the data ten rows per frame, and an
    event annotation becomes visible once the animation has reached its date.

    Every frame is derived purely from the frame number, so replaying the
    animation (for example showing it after it has been saved) starts from an
    empty chart again instead of piling duplicate points on top of a finished
    one. The input DataFrame is not modified.

    Args:
        df: DataFrame indexed by date containing the "MA<n>" columns for every
            entry of MOVING_AVERAGES
        symbol: Stock symbol being plotted (used in the title)
        args: Command line arguments; ``output`` (file to save to) and ``show``
            (open a window) are read
    """
    ma_columns = [f"MA{ma}" for ma in MOVING_AVERAGES]
    ma_values = df[ma_columns].to_numpy(dtype=float)
    if ma_values.size == 0 or np.isnan(ma_values).all():
        # Nothing to draw; the axis limits below could not be computed.
        logging.error(f"No moving average data available to plot for {symbol}")
        return

    sns.set_style("darkgrid")
    fig, ax = plt.subplots(figsize=FIGURE_SIZE)
    plt.grid(False)

    # Define historical events with cool colors for each event
    events = {
        "2000-08-24": ("Dot-com Bubble Burst", "#1f77b4"),  # Steel Blue
        "2007-11-23": ("Global Financial Crisis", "#1f77b4"),
        "2020-03-02": ("COVID-19 Crash", "#1f77b4"),
        "2022-01-10": ("2022 Interest rates downturn", "#1f77b4"),
    }
    # Map row position -> event, keeping only events whose date is in the data
    event_indices = {}
    for date_text, (name, color) in events.items():
        event_date = pd.to_datetime(date_text)
        if event_date in df.index:
            event_indices[df.index.get_loc(event_date)] = (event_date, name, color)

    min_ma = min(MOVING_AVERAGES)
    marker_sizes = [1 * (ma / min_ma) for ma in MOVING_AVERAGES]

    # Set up the plot limits from the full range of all moving averages
    y_min, y_max = np.nanmin(ma_values), np.nanmax(ma_values)
    y_padding = (y_max - y_min) * 0.1
    ax.set_xlim(df.index[0], df.index[-1])
    ax.set_ylim(y_min - y_padding, y_max + y_padding)

    # Numeric x coordinates, kept local so the caller's frame stays untouched
    date_nums = np.asarray(matplotlib.dates.date2num(df.index), dtype=float)
    total_rows = len(df)

    # One up/down scatter pair per MA. ``trails`` runs parallel to ``scatters``
    # and holds, for each scatter, the row positions of its points (ascending)
    # and the matching (x, y) coordinates, computed once instead of per frame.
    scatters = []
    trails = []
    for ma, column, marker_size in zip(MOVING_AVERAGES, ma_columns, marker_sizes):
        scatter_up = ax.scatter(
            [], [], c=UP_COLOR, s=marker_size, alpha=0.6, label=f"{ma}-Day MA"
        )
        scatter_down = ax.scatter([], [], c=DOWN_COLOR, s=marker_size, alpha=0.3)
        scatters.extend([scatter_up, scatter_down])

        values = df[column].to_numpy(dtype=float)
        changes = df[column].diff().to_numpy(dtype=float)
        # The first row has no previous value (NaN change) and lands in neither set
        for mask in (changes > 0, changes < 0):
            rows = np.flatnonzero(mask)
            points = np.column_stack((date_nums[rows], values[rows]))
            trails.append((rows, points))

    plt.title(f"{symbol} Moving Averages", fontsize=14, pad=20)
    plt.xlabel("Date", fontsize=12)
    plt.ylabel("Price", fontsize=12)
    plt.xticks(rotation=45)

    # Create every annotation up front, hidden; frames only toggle visibility
    annotations = []
    for idx, (event_date, name, color) in event_indices.items():
        # .iloc: idx is a row position, not an index label
        price = df[f"MA{min_ma}"].iloc[idx]
        ann = ax.annotate(
            name,
            xy=(event_date, price),
            xytext=(-50, 50),
            textcoords="offset points",
            bbox=dict(boxstyle="round,pad=0.5", fc=color, alpha=0.3, ec=color),
            color="white",
            arrowprops=dict(
                arrowstyle="->",
                connectionstyle="arc3,rad=-0.3",
                alpha=0.6,
                color=color,
            ),
            ha="left",
            va="center",
        )
        ann.set_visible(False)
        annotations.append((idx, ann))
    annotation_artists = [ann for _, ann in annotations]

    frame_step = 10

    def animate(frame):
        current_frame = frame * frame_step
        # Rows [0, visible_rows) are on screen; clamped so late frames hold the full chart
        visible_rows = min(current_frame, total_rows)

        for idx, ann in annotations:
            ann.set_visible(current_frame >= idx)

        for scatter, (rows, points) in zip(scatters, trails):
            # rows is sorted, so this counts the points located before visible_rows
            shown = np.searchsorted(rows, visible_rows)
            scatter.set_offsets(points[:shown])

        # Return all artists that need to be redrawn
        return scatters + annotation_artists

    # Format x-axis dates
    ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d"))

    # Frames needed to reveal every row, including a final partial step
    # (frame 0 shows nothing). The original frame count is kept when it is
    # larger, so the finished chart is still held on screen at the end.
    data_frames = -(-total_rows // frame_step) + 1
    frames = max(total_rows // 5, data_frames)
    anim = FuncAnimation(fig, animate, frames=frames, interval=10, blit=True)
    plt.tight_layout()

    if args.output:
        writer = PillowWriter(fps=30)
        anim.save(args.output, writer=writer)
        logging.info(f"Animation saved to {args.output}")
    if args.show:
        plt.show()
def calculate_moving_averages(df: DataFrame) -> DataFrame:
    """
    Add one rolling-mean column of "Close" per entry of MOVING_AVERAGES.

    The frame is modified in place: an "MA<n>" column is added for each window
    and rows lacking any of those averages are dropped.

    Args:
        df: DataFrame containing stock price data with a "Close" column

    Returns:
        DataFrame: The same DataFrame object, with the moving average columns added
    """
    closes = df["Close"]
    ma_columns = []
    for window in MOVING_AVERAGES:
        label = f"MA{window}"
        df[label] = closes.rolling(window=window).mean()
        ma_columns.append(label)
        logging.debug(f"Calculated {window}-day moving average")

    # Rows at the start have no value for the longer windows; discard them
    df.dropna(subset=ma_columns, inplace=True)
    return df
def main(args: ArgOptions) -> None:
    """
    Run the full pipeline: resolve dates, load prices, compute averages, plot.

    The process exits with status 1 when no price data could be loaded.

    Args:
        args: Validated command line arguments
    """
    logging.info(f"Starting analysis with verbosity level: {args.verbose}")
    logging.debug("Debug logging is enabled")

    date_range = handle_dates(args)
    prices = fetch_stock_data(args.symbol, *date_range)
    if prices.empty:
        logging.error(f"No data available for {args.symbol} in the given date range")
        sys.exit(1)
    logging.info(f"Successfully loaded data for {args.symbol}")

    create_animated_plot(calculate_moving_averages(prices), args.symbol, args)
    # Release the figure once it has been saved and/or shown
    plt.close()
if __name__ == "__main__":
    # Script entry point: parse options, configure logging, then run.
    cli_options = parse_args()
    setup_logging(cli_options.verbose)
    main(cli_options)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment