Created
May 18, 2023 08:02
-
-
Save Qu3tzal/f59063a2e8b8df2d27a435cabfa500b6 to your computer and use it in GitHub Desktop.
Quickly compare the prices of a list of stocks or indices.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import argparse | |
def parse_args(argv=None):
    """Parse the command-line options for the stock/index comparison plot.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, i.e. the same behaviour as
            calling this function with no arguments.

    Returns:
        argparse.Namespace with the attributes ``history``, ``granularity``,
        ``stocks``, ``indices``, ``log``, ``timecolor``, ``pricelines`` and
        ``theme``. ``stocks`` / ``indices`` are ``None`` when the flag is not
        given, otherwise a non-empty list (``nargs='+'``).

    Raises:
        SystemExit: when argparse rejects the arguments (e.g. a value outside
            the allowed ``choices``) or when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description='Quickly compare a list of stocks or indices prices.')
    # Period/interval values are passed straight to yfinance, so restrict them
    # to the strings it accepts.
    parser.add_argument('--history', type=str, default='ytd', help='History period of the price data', choices=["1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"])
    parser.add_argument('--granularity', type=str, default='1d', help='Granularity of the price data', choices=["1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h", "1d", "5d", "1wk", "1mo", "3mo"])
    parser.add_argument('--stocks', nargs='+', help='List of stocks to compare')
    parser.add_argument('--indices', nargs='+', help='List of indices to compare', choices=["SP500", "NASDAQ", "DOWJONES"])
    parser.add_argument('--log', action='store_true', help='Use log scale for the y-axis')
    parser.add_argument('--timecolor', action='store_true', help='Use color to indicate time')
    parser.add_argument('--pricelines', action='store_true', help='Adds lines at regular price intervals')
    parser.add_argument('--theme', type=str, default='dark', help='Theme of the plot', choices=["dark", "light", "seaborn"])
    return parser.parse_args(argv)
def _download_prices(tickers, args):
    """Download price bars for *tickers* using the CLI's period and interval.

    The returned frame always has ``(field, ticker)`` MultiIndex columns so
    callers can uniformly read ``frame["Close", ticker]``. Depending on the
    yfinance version, a single-ticker download comes back either with flat
    columns (``Close``, ``High``, ...) or already as a MultiIndex. Only the
    flat layout is wrapped: wrapping an existing MultiIndex again would nest
    the tuples and break the ``frame["Close", ticker]`` lookup.
    """
    frame = yf.download(tickers, period=args.history, interval=args.granularity, progress=False)
    if not isinstance(frame.columns, pd.MultiIndex):
        frame.columns = pd.MultiIndex.from_product([frame.columns, tickers])
    return frame


def _apply_theme(theme):
    """Configure the global matplotlib style for the CLI *theme* name."""
    if theme == "seaborn":
        # Imported lazily so seaborn is only required when this theme is used.
        import seaborn as sns
        sns.set_theme()
        print("WARNING: When using the seaborn theme consider disabling the --timecolor and --pricelines flags.")
    else:
        themes_map = {
            "dark": "dark_background",
            "light": "default",
        }
        plt.style.use(themes_map[theme])


def _tick_label_format(granularity):
    """Return the ``strftime`` pattern used for x-axis labels at *granularity*."""
    if granularity in ("1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h"):
        return "%H:%M\n%d/%m/%Y"
    if granularity in ("1d", "5d", "1wk"):
        return "%d/%m/%Y"
    # Monthly/quarterly bars ("1mo", "3mo"): the day of month carries no
    # information, so only month and year are shown.
    return "%m/%Y"


def _select_ticks(index, max_ticks=12):
    """Pick about *max_ticks* evenly spaced bar positions and their dates.

    Returns a ``(positions, dates)`` pair where ``positions`` is a ``range`` of
    integer bar positions and ``dates`` the matching entries of *index*.
    """
    if len(index) > max_ticks:
        positions = range(0, len(index), len(index) // max_ticks)
    else:
        positions = range(len(index))
    return positions, index[list(positions)]


def main(args):
    """Download the requested tickers and plot their prices on one chart.

    Args:
        args: Namespace produced by ``parse_args``. This function reads
            ``stocks``, ``indices``, ``history``, ``granularity``, ``theme``,
            ``log``, ``timecolor`` and ``pricelines`` from it.

    Index series are drawn as dashed close-price lines. Each stock gets a solid
    close-price line plus a translucent band between its high and low prices.
    The figure is displayed with ``plt.show()``. Returns early, after printing
    a message, when neither stocks nor indices were requested.
    """
    print("Hello!")
    if args.stocks is None and args.indices is None:
        print("You need to specify at least one stock or index to compare.")
        return
    elif args.stocks is None:
        print("You are comparing the following indices: {}".format(args.indices))
    elif args.indices is None:
        print("You are comparing the following stocks: {}".format(args.stocks))
    else:
        print("You are comparing the following stocks: {} with the following indices {}".format(args.stocks, args.indices))

    if args.indices is not None and args.stocks is not None:
        print("WARNING: You are comparing stocks and indices on the same plot. This can make reading the plot difficult. Consider using --stocks and --indices separately.")
    if args.stocks is not None and len(args.stocks) > 10:
        print("WARNING: You are comparing more than 10 stocks. This can make reading the plot difficult.")

    if args.theme is not None:
        _apply_theme(args.theme)

    # Every series is plotted against its integer bar position, so the x-axis
    # tick labels are taken from a single reference date index.
    # NOTE(review): if index and stock bars do not share a trading calendar
    # (e.g. a non-US stock against a US index) positions will not line up by date.
    reference_index = None

    # Get the data for the indices.
    if args.indices:
        indices_ticker_map = {
            "SP500": "^GSPC",
            "NASDAQ": "^IXIC",
            "DOWJONES": "^DJI"
        }
        indices_tickers = [indices_ticker_map[index] for index in args.indices]
        indices_data = _download_prices(indices_tickers, args)
        for name in args.indices:
            symbol = indices_ticker_map[name]
            close_price = indices_data["Close", symbol].to_numpy()
            plt.plot(close_price, linestyle='dashed', label="{} ({})".format(name, symbol))
        reference_index = indices_data.index

    # Get the data for the stocks (skipped entirely for an indices-only plot).
    if args.stocks:
        data = _download_prices(args.stocks, args)
        for ticker in args.stocks:
            close_price = data["Close", ticker].to_numpy()
            high_price = data["High", ticker].to_numpy()
            low_price = data["Low", ticker].to_numpy()
            plt.plot(close_price, label=ticker)
            plt.fill_between(range(len(close_price)), high_price, low_price, alpha=0.3)
        # When both kinds are plotted, stock bars drive the axis labels.
        reference_index = data.index

    # Try to only keep about a dozen labels.
    xticks_indices, xticks_dates = _select_ticks(reference_index)
    label_format = _tick_label_format(args.granularity)
    xticks_labels = [date.strftime(label_format) for date in xticks_dates]

    if args.log:
        plt.yscale("log")
    if args.timecolor:
        # Shade every other interval between consecutive tick positions.
        for i, xi in enumerate(xticks_indices):
            if i % 2 == 0 and i < len(xticks_indices) - 1:
                plt.axvspan(xi, xticks_indices[i + 1], facecolor='grey', alpha=0.1)
    if args.pricelines:
        plt.grid(True)
    plt.xticks(xticks_indices, xticks_labels, rotation=45)
    plt.xlabel("Date")
    plt.ylabel("Price (at close, with high/low)")
    plt.legend()
    plt.title("Comparison of close prices of stocks ({}/{})".format(args.history, args.granularity))
    plt.margins(x=0, tight=True)
    plt.show()
if __name__ == "__main__":
    # Script entry point: read the command line and build the comparison plot.
    main(parse_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment