Last active
December 26, 2024 08:49
-
-
Save Stfort52/f4cc97827a0666fe0169d92fd9444b48 to your computer and use it in GitHub Desktop.
Candlestick figure maker for fun
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import matplotlib.pyplot as plt | |
def draw_candlestick(data, time_axis, window_size, **kwargs) -> plt.Figure:
    """
    Draw a candlestick chart from a continuous series.

    Consecutive runs of ``window_size`` samples are grouped into one candle:
    open/close are the first/last value of the run, high/low are its
    maximum/minimum.

    Parameters:
        data (list or np.ndarray): Continuous data points.
        time_axis (list or np.ndarray): Corresponding time points, either
            numbers (discrete steps) or datetimes/timestamps.
        window_size (int): Number of data points per candle; must be >= 1.
            The final candle may contain fewer points.
        **kwargs: Additional arguments for `plt.subplots()`; ``figsize``
            defaults to ``(10, 6)``.

    Returns:
        plt.Figure: The figure containing the chart.

    Raises:
        ValueError: If `data` and `time_axis` differ in length, or if
            `window_size` is less than 1.
    """
    if len(data) != len(time_axis):
        raise ValueError("`data` and `time_axis` must have the same length.")
    if window_size < 1:
        raise ValueError("`window_size` must be at least 1.")

    # Strip any pandas index from the inputs so rows pair up by position and
    # the frame gets a fresh RangeIndex; the binning below relies on that.
    df = pd.DataFrame(
        {
            "time": pd.Series(time_axis).to_numpy(),
            "value": pd.Series(data).to_numpy(),
        }
    )
    df["bin"] = df.index // window_size

    # One row per candle.
    candles = df.groupby("bin").agg(
        open=("value", "first"),
        close=("value", "last"),
        high=("value", "max"),
        low=("value", "min"),
        start_time=("time", "first"),
        end_time=("time", "last"),
    )

    # A candle's horizontal extent is its first-to-last time span. Bins holding
    # a single sample have zero span, so borrow the distance between
    # neighbouring bin starts to keep their body visible (a lone single-sample
    # bin has no neighbour and keeps zero width).
    spans = candles["end_time"] - candles["start_time"]
    start_gaps = candles["start_time"].diff().bfill()
    is_point = (candles["end_time"] == candles["start_time"]) & start_gaps.notna()
    spans = spans.mask(is_point, start_gaps)

    kwargs.setdefault("figsize", (10, 6))
    fig, ax = plt.subplots(**kwargs)

    for row, span in zip(candles.itertuples(index=False), spans):
        # Rising (or flat) candles red, falling candles blue, like vivarepublica.
        color = "red" if row.close >= row.open else "blue"
        # start + half-span instead of (start + end) / 2 so that timestamp
        # axes work too: timestamps can be subtracted but not added.
        center = row.start_time + (row.end_time - row.start_time) / 2
        body_width = span * 0.9
        # Candle body: filled rectangle from open to close, centred on the wick.
        ax.add_patch(
            plt.Rectangle(
                (center - body_width / 2, min(row.open, row.close)),
                body_width,
                abs(row.close - row.open),
                color=color,
                alpha=1,
            )
        )
        # High-low wick.
        ax.vlines(x=center, ymin=row.low, ymax=row.high, color=color, linewidth=1)

    # add_patch only updates the data limits; rescale the view so every body fits.
    ax.autoscale_view()
    return fig
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment