Baseline: stable-baselines3 & backtesting.py. A gymnasium Env wraps a step-wise Backtest so a PPO agent can trade a synthetic sine-wave market.
import numpy as np
import pandas as pd
from backtesting import Backtest, Strategy  # assumes a step-wise Backtest variant exposing initialize()/next()
from gymnasium import spaces, Env
from stable_baselines3 import PPO
class periodicStrategy(Strategy):
    def init(self):
        print(f"Start with equity={self.equity:.2f}")

    def next(self, action: int | None = None):
        print(f"Action={action} Equity={self.equity:.2f} Date={self.data.index[-1]}")
        if action:  # 0/None means hold
            if action == 1:
                self.buy()
            elif action == 2:
                self.position.close()

    def observation(self):
        # min-max normalize the last 20 closes
        closes = self.data.Close[-20:]
        closes = (closes - closes.min()) / (closes.max() - closes.min())
        return [closes]
class CustomEnv(Env):
    """Custom Environment that follows the gymnasium interface."""

    def __init__(self, bt: Backtest):
        # observation (1, 20) = the last 20 normalized close prices
        self.observation_space = spaces.Box(low=-1, high=1, shape=(1, 20), dtype=np.float32)
        # actions: 0 = hold, 1 = buy, 2 = close the position
        self.action_space = spaces.Discrete(3)
        self.bt = bt

    def reward_calculation(self):
        # +1 if equity grew since the previous step, -1 otherwise
        current_equity = self.bt._step_strategy.equity
        reward = 1 if self.previous_equity < current_equity else -1
        self.previous_equity = current_equity
        return reward

    def check_done(self):
        # stop two bars before the end of the data so next() never overruns
        if self.bt._step_time + 2 > len(self.bt._data):
            self.render()
            return True
        return False
    def step(self, action):
        # advance the backtest one bar with the chosen action, then observe
        self.bt.next(action=action)
        obs = self.bt._step_strategy.observation()
        reward = self.reward_calculation()
        done = self.check_done()
        info = {}
        # terminated is always False (the market never "finishes");
        # done is reported as truncated (the dataset can run out)
        return obs, reward, False, done, info
    def reset_backtesting(self):
        # re-initialize the backtest, then advance until observation() spans 20 bars
        self.bt.initialize()
        self.bt.next()
        while True:
            obs = self.bt._step_strategy.observation()
            if np.shape(obs) == (1, 20):
                break
            self.bt.next()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.previous_equity = 10  # starting cash
        self.reset_backtesting()
        return self.bt._step_strategy.observation(), {}
    def render(self, mode='human'):
        result = self.bt.next(done=True)  # finalize the run and collect the stats
        self.bt.plot(results=result, open_browser=False)

    def close(self):
        pass
def generate_sin_wave(periods: int = 1000, amplitude: float = 1.0) -> pd.DataFrame:
    x = pd.date_range(start='2023-01-01', periods=periods, freq='D')
    # sine wave shifted by +2 so prices stay positive
    y = amplitude * pd.Series(data=np.sin(np.linspace(0, 10 * np.pi, periods)), index=x) + 2
    # Create a DataFrame with the OHLCV columns backtesting.py expects
    data = pd.DataFrame({'Open': y, 'High': y, 'Low': y, 'Close': y, 'Volume': y})
    return data
data = generate_sin_wave()
print(data)

# Instantiate the env
bt = Backtest(data, periodicStrategy, cash=10)
env = CustomEnv(bt)
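# Optional (an assumption, not part of the original gist): stable-baselines3
# ships an env checker that validates the spaces and the reset()/step()
# signatures; it may warn that observation() returns float64 values against
# the float32 Box above.
# from stable_baselines3.common.env_checker import check_env
# check_env(env)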
# env = VecNormalize(env)
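# Note (assumption, not in the original gist): VecNormalize only wraps
# vectorized envs, so the line above would first need e.g.:
# from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
# env = VecNormalize(DummyVecEnv([lambda: env]))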
# Define and train the agent
model = PPO("MlpPolicy", env, verbose=0, tensorboard_log="./logs/")
model.learn(total_timesteps=1000000, log_interval=1)
# model.save("")
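A minimal sketch of querying the trained policy after learning; the rollout loop below and its step count are assumptions, not part of the original gist:

obs, _ = env.reset()
for _ in range(200):
    action, _state = model.predict(np.array(obs, dtype=np.float32), deterministic=True)
    obs, reward, terminated, truncated, info = env.step(int(action))
    if terminated or truncated:
        break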