Created
October 25, 2024 13:26
-
-
Save namuan/65e6ba8e3793095e34c819cae2accfad to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Script for plotting animated moving averages of stock prices. | |
Examples: | |
python3 spy_30y.py --show | |
""" | |
import logging | |
import sys | |
from argparse import ArgumentParser, RawDescriptionHelpFormatter | |
from dataclasses import dataclass | |
from datetime import datetime, timedelta | |
from typing import Optional, Tuple | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
from matplotlib.animation import FuncAnimation, PillowWriter | |
from pandas import DataFrame | |
from common.market import get_cached_data | |
# Constants
# Rolling-window sizes (in rows); one "MA<n>" column is computed and plotted per entry.
MOVING_AVERAGES = [5, 20, 50, 100, 150, 200]
DEFAULT_SYMBOL = "SPY"  # default value of the --symbol option
DEFAULT_PERIOD = 1  # default value of the --period option
DEFAULT_LOOKBACK_YEARS = 30  # start-date fallback when neither --start_date nor --year is given
DATE_FORMAT = "%Y-%m-%d"  # used to parse CLI dates and as the logging datefmt
UP_COLOR = "#20A428"  # Dark green
DOWN_COLOR = "#EB2119"  # Dark red
FIGURE_SIZE = (15, 8)  # passed as figsize to plt.subplots
DPI = 300  # NOTE(review): not referenced anywhere in the visible code
ANIMATION_INTERVAL = 50  # milliseconds between frames; NOTE(review): not referenced, FuncAnimation is given interval=10
ANIMATION_TRAIL_LENGTH = 50  # number of points to show in trail; NOTE(review): not referenced in the visible code
def setup_logging(verbosity: int) -> None:
    """
    Configure the root logger to write to stdout at a level chosen by verbosity.

    Every handler already attached to the root logger is removed first, because
    ``logging.basicConfig`` does nothing while the root logger still has
    handlers; leaving even one behind would silently keep the old
    configuration.

    Args:
        verbosity: Integer indicating verbosity level (0=WARNING, 1=INFO, 2+=DEBUG)
    """
    logging_level = logging.WARNING
    if verbosity == 1:
        logging_level = logging.INFO
    elif verbosity >= 2:
        logging_level = logging.DEBUG

    root = logging.getLogger()
    # Iterate over a snapshot: removeHandler() mutates root.handlers, and
    # removing from the live list while iterating it skips every other handler.
    for handler in list(root.handlers):
        root.removeHandler(handler)

    logging.basicConfig(
        stream=sys.stdout,
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
        datefmt=DATE_FORMAT,
        level=logging_level,
    )
    logging.debug(
        "Logging configured with level: %s", logging.getLevelName(logging_level)
    )
def validate_dates(start_date: datetime.date, end_date: datetime.date) -> bool:
    """
    Check that a date range is ordered and does not extend past today.

    Args:
        start_date: Starting date for analysis
        end_date: Ending date for analysis

    Returns:
        bool: False (after logging an error) when start_date is later than
        end_date, or when end_date is later than today's date; True otherwise
    """
    if end_date < start_date:
        logging.error("Start date must be before end date")
        return False

    ends_in_future = end_date > datetime.now().date()
    if ends_in_future:
        logging.error("End date cannot be in the future")
    return not ends_in_future
@dataclass
class ArgOptions:
    """Container for the parsed command line options."""

    verbose: int
    symbol: str
    year: Optional[int]
    start_date: Optional[str]
    end_date: Optional[str]
    period: int
    output: Optional[str]
    show: bool

    def validate(self) -> bool:
        """
        Check the options for consistency, logging the first problem found.

        The date strings are only checked when both start_date and end_date
        are provided.

        Returns:
            bool: True if all validations pass, False otherwise
        """
        if self.period <= 0:
            logging.error("Period must be a positive integer")
            return False

        if self.start_date and self.end_date:
            try:
                start, end = [
                    datetime.strptime(text, DATE_FORMAT).date()
                    for text in (self.start_date, self.end_date)
                ]
            except ValueError as e:
                logging.error(f"Invalid date format: {str(e)}")
                return False
            if start > end:
                logging.error("Start date must be before end date")
                return False

        if not (self.output or self.show):
            logging.error("Either --output or --show (or both) must be specified")
            return False

        return True
def parse_args() -> ArgOptions:
    """
    Build the command line parser, parse sys.argv and validate the result.

    The process exits with status 1 when validation fails.

    Returns:
        ArgOptions: Validated command line arguments
    """
    cli = ArgumentParser(
        description=__doc__, formatter_class=RawDescriptionHelpFormatter
    )

    # Logging verbosity: each -v raises the count by one.
    cli.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        dest="verbose",
        help="Increase verbosity of logging output",
    )

    # What to analyse.
    cli.add_argument(
        "-s",
        "--symbol",
        type=str,
        default=DEFAULT_SYMBOL,
        help=f"Stock symbol to analyze (default: {DEFAULT_SYMBOL})",
    )

    # Date range selection.
    cli.add_argument("-y", "--year", type=int, help="Starting year for analysis")
    cli.add_argument(
        "-sd",
        "--start_date",
        type=str,
        help="Start date for analysis (format: YYYY-MM-DD)",
    )
    cli.add_argument(
        "-ed",
        "--end_date",
        type=str,
        help="End date for analysis (format: YYYY-MM-DD, default: today)",
    )
    cli.add_argument(
        "-p",
        "--period",
        type=int,
        default=DEFAULT_PERIOD,
        help=f"Period for price change calculation (in days, default: {DEFAULT_PERIOD})",
    )

    # Where the animation goes.
    cli.add_argument(
        "-o",
        "--output",
        type=str,
        help="Output file path to save the animation (e.g., /path/to/plot.gif)",
    )
    cli.add_argument(
        "--show",
        action="store_true",
        help="Display the animation window (default: False)",
    )

    # The parsed namespace holds exactly the ArgOptions fields
    # (verbose, symbol, year, start_date, end_date, period, output, show).
    options = ArgOptions(**vars(cli.parse_args()))
    if options.validate():
        return options
    sys.exit(1)
def fetch_stock_data(
    symbol: str, start_date: datetime.date, end_date: datetime.date
) -> DataFrame:
    """
    Load price data for a symbol over a date range via get_cached_data.

    Any exception raised while loading is logged and swallowed; callers
    detect failure by checking whether the returned frame is empty.

    Args:
        symbol: Stock symbol
        start_date: Start date for data fetch
        end_date: End date for data fetch

    Returns:
        DataFrame: Stock price data, or an empty DataFrame if loading failed
    """
    logging.debug(f"Fetching data for {symbol} from {start_date} to {end_date}")
    try:
        prices = get_cached_data(symbol, start=start_date, end=end_date)
        if prices.empty:
            logging.warning(f"Data for symbol {symbol} is empty")
    except Exception as err:
        logging.error(f"Error fetching data for {symbol}: {str(err)}")
        prices = pd.DataFrame()
    return prices
def handle_dates(args: ArgOptions) -> Tuple[datetime.date, datetime.date]:
    """
    Resolve the analysis date range from the command line options.

    The end date is --end_date if given, otherwise today. The start date is
    taken from --start_date, else January 1st of --year, else
    DEFAULT_LOOKBACK_YEARS * 365 days before the end date. The process exits
    with status 1 if a value cannot be parsed or the range fails
    validate_dates.

    Args:
        args: Command line arguments

    Returns:
        Tuple[datetime.date, datetime.date]: Start and end dates
    """

    def to_date(text):
        return datetime.strptime(text, DATE_FORMAT).date()

    try:
        end_date = to_date(args.end_date) if args.end_date else datetime.now().date()
        if args.start_date:
            start_date = to_date(args.start_date)
        elif args.year:
            start_date = datetime(args.year, 1, 1).date()
        else:
            lookback = timedelta(days=DEFAULT_LOOKBACK_YEARS * 365)
            start_date = end_date - lookback
    except ValueError as e:
        logging.error(f"Invalid date format: {str(e)}")
        sys.exit(1)

    if not validate_dates(start_date, end_date):
        sys.exit(1)

    logging.debug(f"Date range: {start_date} to {end_date}")
    return start_date, end_date
import matplotlib.dates | |
def create_animated_plot(df: DataFrame, symbol: str, args: ArgOptions) -> None:
    """
    Create an animated scatter plot of moving averages with historical event annotations.

    Each moving average is drawn as two scatter series: rows where the average
    rose versus the previous row (UP_COLOR) and rows where it fell
    (DOWN_COLOR). Every animation frame reveals ``frame_step`` more rows.

    All points are precomputed once, and each frame simply selects how many
    of them to show. Rendering a frame is therefore idempotent: replaying the
    animation (repeat loop, or showing after saving) restarts from an empty
    plot instead of accumulating duplicate points. The input DataFrame is not
    modified.

    Args:
        df: Date-indexed DataFrame containing an "MA<n>" column for every
            entry of MOVING_AVERAGES
        symbol: Stock symbol being plotted (used in the title)
        args: Command line arguments; ``output`` (save path) and ``show``
            (open a window) are read
    """
    frame_step = 10  # number of rows revealed per animation frame
    row_count = len(df)
    ma_columns = [f"MA{ma}" for ma in MOVING_AVERAGES]

    # Without at least one finite moving-average value there are no axis
    # limits to compute and nothing to draw.
    ma_values = df[ma_columns].to_numpy(dtype=float)
    if ma_values.size == 0 or np.isnan(ma_values).all():
        logging.error(f"No moving average data available to plot for {symbol}")
        return

    sns.set_style("darkgrid")
    fig, ax = plt.subplots(figsize=FIGURE_SIZE)
    plt.grid(False)

    # Define historical events with cool colors for each event
    events = {
        "2000-08-24": ("Dot-com Bubble Burst", "#1f77b4"),  # Steel Blue
        "2007-11-23": ("Global Financial Crisis", "#1f77b4"),
        "2020-03-02": ("COVID-19 Crash", "#1f77b4"),
        "2022-01-10": ("2022 Interest rates downturn", "#1f77b4"),
    }

    # Map the positional row of each event to its details; events whose date
    # is not present in the index are skipped.
    event_indices = {}
    for date_text, (name, color) in events.items():
        date = pd.to_datetime(date_text)
        if date in df.index:
            event_indices[df.index.get_loc(date)] = (date, name, color)

    min_ma = min(MOVING_AVERAGES)
    marker_sizes = [ma / min_ma for ma in MOVING_AVERAGES]

    # Set up the plot limits
    y_min, y_max = np.nanmin(ma_values), np.nanmax(ma_values)
    y_padding = (y_max - y_min) * 0.1

    # Numeric dates for the scatter offsets, kept local so the caller's
    # DataFrame does not gain an extra column.
    date_nums = np.asarray(matplotlib.dates.date2num(df.index), dtype=float)

    ax.set_xlim(df.index[0], df.index[-1])
    ax.set_ylim(y_min - y_padding, y_max + y_padding)

    # Initialize scatter plots for each MA (order: up, down, up, down, ...)
    scatters = []
    for ma, marker_size in zip(MOVING_AVERAGES, marker_sizes):
        scatter_up = ax.scatter(
            [], [], c=UP_COLOR, s=marker_size, alpha=0.6, label=f"{ma}-Day MA"
        )
        scatter_down = ax.scatter([], [], c=DOWN_COLOR, s=marker_size, alpha=0.3)
        scatters.extend([scatter_up, scatter_down])

    plt.title(f"{symbol} Moving Averages", fontsize=14, pad=20)
    plt.xlabel("Date", fontsize=12)
    plt.ylabel("Price", fontsize=12)
    plt.xticks(rotation=45)

    # Precompute, in the same order as `scatters`, the full point set of each
    # series plus a running count: visible_counts[k] is how many of those
    # points lie within the first k rows. The first row has no previous value,
    # so its change is NaN and it belongs to neither series.
    tracks = []
    for column in ma_columns:
        change = df[column].diff()
        points = np.column_stack((date_nums, df[column].to_numpy(dtype=float)))
        for mask in ((change > 0).to_numpy(), (change < 0).to_numpy()):
            visible_counts = np.concatenate(([0], np.cumsum(mask)))
            tracks.append((points[mask], visible_counts))

    # Create every annotation up front, hidden; frames only toggle visibility.
    anchor_prices = df[f"MA{min_ma}"]
    annotations = []
    for idx, (date, name, color) in event_indices.items():
        ann = ax.annotate(
            name,
            # .iloc: idx is a positional row number, not an index label
            xy=(date, anchor_prices.iloc[idx]),
            xytext=(-50, 50),
            textcoords="offset points",
            bbox=dict(boxstyle="round,pad=0.5", fc=color, alpha=0.3, ec=color),
            color="white",
            arrowprops=dict(
                arrowstyle="->",
                connectionstyle="arc3,rad=-0.3",
                alpha=0.6,
                color=color,
            ),
            ha="left",
            va="center",
        )
        ann.set_visible(False)
        annotations.append((idx, ann))
    annotation_artists = [ann for _, ann in annotations]

    def animate(frame):
        # Clamp so the final frame reveals the trailing partial chunk of rows.
        rows_shown = min(frame * frame_step, row_count)

        for idx, ann in annotations:
            ann.set_visible(rows_shown >= idx)

        for scatter, (points, visible_counts) in zip(scatters, tracks):
            scatter.set_offsets(points[: visible_counts[rows_shown]])

        # Return all artists that need to be redrawn
        return scatters + annotation_artists

    # Format x-axis dates
    ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d"))

    # Frame 0 shows nothing; ceil(row_count / frame_step) further frames
    # reveal every row, including the last partial chunk.
    total_frames = (row_count + frame_step - 1) // frame_step + 1
    anim = FuncAnimation(fig, animate, frames=total_frames, interval=10, blit=True)
    plt.tight_layout()

    if args.output:
        writer = PillowWriter(fps=30)
        anim.save(args.output, writer=writer)
        logging.info(f"Animation saved to {args.output}")
    if args.show:
        plt.show()
def calculate_moving_averages(df: DataFrame) -> DataFrame:
    """
    Add one rolling-mean column of "Close" per window in MOVING_AVERAGES.

    The frame is modified in place: the "MA<n>" columns are added and every
    row in which any of them is NaN is dropped.

    Args:
        df: DataFrame containing stock price data with a "Close" column

    Returns:
        DataFrame: The same DataFrame object with the moving average columns added
    """
    closes = df["Close"]
    ma_columns = []
    for window in MOVING_AVERAGES:
        name = f"MA{window}"
        df[name] = closes.rolling(window=window).mean()
        ma_columns.append(name)
        logging.debug(f"Calculated {window}-day moving average")

    # Rows before the longest window has filled up carry NaN averages.
    df.dropna(subset=ma_columns, inplace=True)
    return df
def main(args: ArgOptions) -> None:
    """
    Run the pipeline: resolve dates, load prices, compute averages, plot.

    Exits with status 1 when no price data is available for the symbol.

    Args:
        args: Validated command line arguments
    """
    logging.info(f"Starting analysis with verbosity level: {args.verbose}")
    logging.debug("Debug logging is enabled")

    date_range = handle_dates(args)
    prices = fetch_stock_data(args.symbol, *date_range)
    if prices.empty:
        logging.error(f"No data available for {args.symbol} in the given date range")
        sys.exit(1)
    logging.info(f"Successfully loaded data for {args.symbol}")

    create_animated_plot(calculate_moving_averages(prices), args.symbol, args)
    plt.close()
if __name__ == "__main__":
    # Script entry point: parse and validate options, configure logging, run.
    cli_options = parse_args()
    setup_logging(cli_options.verbose)
    main(cli_options)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment