Created
August 31, 2023 01:02
-
-
Save Jiuh-star/2076fa0fc94e4bf4352357c200b2faa0 to your computer and use it in GitHub Desktop.
Mini backtest framework
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Requirements: | |
python~=3.10 | |
numpy | |
optuna | |
loguru | |
""" | |
from __future__ import annotations | |
import datetime | |
import math | |
import time | |
import typing as T | |
from abc import ABC, abstractmethod | |
from collections import defaultdict | |
from dataclasses import dataclass, fields | |
import numpy as np | |
import optuna | |
import typing_extensions as TE | |
from loguru import logger | |
from numpy.typing import NDArray | |
# Type variable for the concrete ``Table`` subclass stored in a ``Market``.
_T = T.TypeVar("_T", bound="Table", covariant=True)

# Names exported by ``from <module> import *``.
__all__ = [
    "Table",
    "Market",
    "MarketView",
    "Order",
    "OrderHelper",
    "Broker",
    "ParamRange",
    "Strategy",
    "Backtest",
]
@dataclass
class Table:
    """Column-oriented series for a single tag.

    Every dataclass field except ``tag`` is treated as a column (see
    ``columns``); subclasses can add columns by declaring more fields.
    All columns must have the same length as ``date``.

    Raises:
        ValueError: if a column's length differs from ``len(date)``.
    """

    tag: str
    date: NDArray[np.datetime64]
    # ``np.float_`` was removed in NumPy 2.0; ``np.float64`` is the same type.
    close: NDArray[np.float64]

    def __post_init__(self):
        self._len = len(self.date)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        mismatched = [name for name in self.columns if len(getattr(self, name)) != self._len]
        if mismatched:
            raise ValueError(
                f"columns {mismatched} of table {self.tag!r} do not have length {self._len}"
            )

    @property
    def columns(self) -> list[str]:
        """Names of all dataclass fields except ``tag``."""
        return [f.name for f in fields(self) if f.name != "tag"]

    def __len__(self):
        """Number of rows, i.e. ``len(date)`` at construction time."""
        return self._len
class Market(T.Generic[_T]):
    """A set of equally dated tables, addressable by tag and iterable in time.

    ``current`` starts at ``first`` (index 0). Each ``next()`` replaces it with
    a new ``MarketView`` one index further, so iterating a market yields the
    views for indices ``1 .. len(market) - 1``; index 0 is never yielded.
    """

    tag_table: dict[str, _T]

    def __init__(self, tables: T.Iterable[_T]) -> None:
        """Build the market from ``tables``.

        Raises:
            ValueError: if ``tables`` is empty or a table's ``date`` column
                differs from the first table's.
        """
        # Materialise once: the input is walked several times below and may be
        # a one-shot iterator such as a generator.
        tables = list(tables)
        if not tables:
            raise ValueError("Market requires at least one table")
        dates = tables[0].date
        for table in tables[1:]:
            if not np.array_equal(dates, table.date):
                raise ValueError(
                    f"dates of table {table.tag!r} differ from those of table {tables[0].tag!r}"
                )
        self.tag_table = {table.tag: table for table in tables}
        dates = dates.tolist()
        dates = T.cast(list[datetime.date], dates)
        self.dates: list[datetime.date] = dates
        self.tags = list(self.tag_table.keys())
        self.first = MarketView(market=self, index=0)
        self.last = MarketView(market=self, index=len(self) - 1)
        self.current = self.first

    @property
    def is_end(self) -> bool:
        """True once ``current`` sits on (or past) the last index."""
        return self.current.index >= len(self) - 1

    def reset(self):
        """Rewind ``current`` to the first view."""
        self.current = self.first

    def setup(self, backtest: Backtest):
        """Backtest hook: rewind the market before a run."""
        self.reset()

    def teardown(self, backtest: Backtest):
        """Backtest hook: nothing to release."""
        pass

    def __getitem__(self, index):
        """Return the table stored under tag ``index``."""
        return self.tag_table[index]

    def __setitem__(self, index, value):
        # Only ``tag_table`` is updated; ``tags`` and ``dates`` keep their
        # constructor-time values.
        self.tag_table[index] = value

    def __delitem__(self, index):
        # Only ``tag_table`` is updated; ``tags`` keeps its constructor-time value.
        del self.tag_table[index]

    def __iter__(self):
        # The market is its own iterator and continues from ``current``.
        return self

    def __next__(self):
        if self.is_end:
            raise StopIteration
        self.current = MarketView(market=self, index=self.current.index + 1)
        return self.current

    def __len__(self):
        """Number of points in time."""
        return len(self.dates)

    def __repr__(self):
        repr_params = {
            "scale": len(self.tags),
            "start_date": self.dates[0],
            "stop_date": self.dates[-1],
            "duration": len(self),
        }
        return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in repr_params.items())})"
class MarketView:
    """The market as seen at one index.

    Besides ``get`` / ``get_all``, columns are reachable as attributes:
    ``view.close`` is ``get_all("close")`` and ``view.historical_close`` is
    ``get_all("close", history=True)``. History slices are ``column[:index]``,
    so they stop just before the current point.
    """

    historical_prefix = "historical_"

    def __init__(self, market: Market, index: int):
        self.market = market
        self.index = index
        # Keyed by (column name, history flag) so that current values and
        # history slices of the same column never shadow each other.
        self._cache: dict[tuple[str, bool], dict] = {}

    @T.overload
    def get(self, tag: str, name: T.Literal["date"]) -> np.datetime64:
        ...

    @T.overload
    def get(self, tag: str, name: str) -> float:
        ...

    @T.overload
    def get(self, tag: str, name: str, *, history: T.Literal[False]) -> float:
        ...

    @T.overload
    def get(self, tag: str, name: str, *, history: T.Literal[True]) -> NDArray[np.float64]:
        ...

    @T.overload
    def get(self, tag: str, name: str, *, history: bool = False) -> float | NDArray[np.float64]:
        ...

    def get(self, tag: str, name: str, *, history: bool = False):
        """Return column ``name`` of table ``tag`` at this view.

        With ``history=True`` the slice ``column[:index]`` is returned,
        otherwise the single element ``column[index]``.
        """
        line: NDArray = getattr(self.market[tag], name)
        if history:
            return line[: self.index]
        return line[self.index]

    @T.overload
    def get_all(self, name: T.Literal["date"]) -> dict[str, np.datetime64]:
        ...

    @T.overload
    def get_all(self, name: str) -> dict[str, float]:
        ...

    @T.overload
    def get_all(self, name: str, *, history: T.Literal[True]) -> dict[str, NDArray[np.float64]]:
        ...

    @T.overload
    def get_all(self, name: str, *, history: T.Literal[False]) -> dict[str, float]:
        ...

    @T.overload
    def get_all(self, name: str, *, history: bool = False) -> dict[str, float | NDArray[np.float64]]:
        ...

    def get_all(self, name: str, *, history: bool = False):
        """Return ``{tag: get(tag, name, history=history)}`` for every market tag.

        The result is cached per ``(name, history)`` pair on this view.
        """
        key = (name, bool(history))
        if key not in self._cache:
            self._cache[key] = {tag: self.get(tag, name, history=history) for tag in self.market.tags}
        return self._cache[key]

    @property
    def date(self) -> datetime.date:
        """The market date at this view's index."""
        return self.market.dates[self.index]

    def __getattr__(self, name: str):
        # Only reached when normal attribute lookup fails. Private and dunder
        # names are never columns; refusing them here also prevents unbounded
        # recursion when ``_cache`` / ``market`` are not set yet (copy, pickle).
        if name.startswith("_"):
            raise AttributeError(name)
        if name.startswith(self.historical_prefix):
            # Strip the prefix so the lookup hits the real column name.
            return self.get_all(name.removeprefix(self.historical_prefix), history=True)
        return self.get_all(name)

    @property
    def is_end(self) -> bool:
        """True when this view is at (or past) the market's last index."""
        return self.index >= len(self.market) - 1
@dataclass(eq=False)
class Order:
    """A single order with a process-wide unique ``id``.

    ``size`` is signed: positive is long (buy), negative is short (sell).
    ``start`` / ``stop`` are market indices; ``stop`` is written by ``fill``
    and ``cancel``. Equality and hashing are both based on ``id``, so list
    operations such as ``remove`` / ``in`` always refer to this very order and
    never to another order that merely has the same field values.
    """

    tag: str
    size: int = 0
    price: float = 0.0
    start: int = 0
    stop: int = 0
    commission: float = 0.0
    status: T.Literal["open", "filled", "cancelled"] = "open"
    # Next id to hand out; shared by ``Order`` and all of its subclasses.
    _iota: T.ClassVar[int] = 0
    # Registry of every order created; it is never pruned in this class.
    _orders: T.ClassVar[list[Order]] = []

    def __post_init__(self):
        # Increment on ``Order`` itself: ``type(self)._iota += 1`` would create
        # a separate counter on a subclass and hand out duplicate ids.
        self.id = Order._iota
        Order._iota += 1
        self.order_type = "limit"
        self.time_in_force = "GTC"
        self.__class__._orders.append(self)

    @property
    def is_long(self):
        """True for a positive (buy) size."""
        return bool(self.size > 0)

    @property
    def is_short(self):
        """True for a negative (sell) size."""
        return bool(self.size < 0)

    def fill(self, index: int):
        """Mark the order as filled at market index ``index``."""
        self.stop = index
        self.status = "filled"

    def cancel(self, index: int):
        """Mark the order as cancelled at market index ``index``."""
        self.stop = index
        self.status = "cancelled"

    def __eq__(self, other):
        if not isinstance(other, Order):
            return NotImplemented
        return self.id == other.id

    def __hash__(self):
        return self.id
class OrderHelper:
    """Validation, default-filling and execution rules for orders.

    All methods are classmethods and read/write the broker passed to them.
    """

    @classmethod
    def is_valid(cls, order: Order, broker: Broker) -> bool:
        """Return True if ``order`` is open, names a known tag, has a non-zero
        size and a finite, positive price."""
        if order.status not in ["open"]:
            return False
        if order.tag not in broker.market.tags:
            return False
        if order.size == 0:
            return False
        # ``nan <= 0`` is False, so non-finite prices need an explicit check.
        if not math.isfinite(order.price) or order.price <= 0:
            return False
        return True

    @classmethod
    def is_tradable(cls, order: Order, broker: Broker, *, skip_valid: bool = False) -> bool:
        """Return True if ``order`` can execute against the current close.

        A long order needs ``price >= close`` and enough cash; any other order
        needs ``price <= close`` and enough holdings of ``order.tag``.
        """
        if not skip_valid and not cls.is_valid(order, broker):
            return False
        close = broker.market.current.get(order.tag, "close")
        # Comparisons with NaN are always False, which would let an order pass
        # both limit checks below on a point that has no usable price.
        if not math.isfinite(close):
            return False
        if order.is_long:
            if order.price < close:
                return False
            if order.size > broker.cash * (1 - order.commission) / order.price:
                return False
        else:
            if order.price > close:
                return False
            # ``.get`` keeps this check read-only: indexing a defaultdict
            # would insert a zero entry for the tag.
            if abs(order.size) > broker.account.get(order.tag, 0):
                return False
        return True

    @classmethod
    def write_blank_order(cls, order: Order, broker: Broker, is_long: bool):
        """Fill in the fields of ``order`` that are still at their zero default.

        * ``price`` 0.0 becomes the current close of ``order.tag``;
        * ``size`` 0 becomes the largest whole size the cash affords
          (``is_long``) or minus the whole current holding (otherwise);
        * ``start`` 0 becomes the current market index.
        """
        if order.price == 0.0:
            order.price = broker.market.current.get(order.tag, "close")
        if order.size == 0:
            if is_long:
                # With a zero or non-finite price no size can be derived; the
                # size then stays 0 and ``is_valid`` rejects the order.
                if math.isfinite(order.price) and order.price > 0:
                    order.size = int(broker.cash * (1 - order.commission) / order.price)
            else:
                order.size = -int(broker.account.get(order.tag, 0))
        if order.start == 0:
            order.start = broker.market.current.index

    @classmethod
    def trade_an_order(cls, order: Order, broker: Broker, *, skip_check: bool = False) -> bool:
        """Execute ``order`` on ``broker``; return True if it was filled.

        On success cash is reduced by the signed order value plus commission,
        the holding of ``order.tag`` changes by ``order.size``, and the order
        moves from the broker's opening orders to its filled orders.
        """
        if not skip_check and not cls.is_tradable(order, broker):
            return False
        signed_value: float = order.size * order.price
        commission = abs(signed_value * order.commission)
        broker.account[broker.cash_name] -= signed_value + commission
        broker.account[order.tag] += order.size
        order.fill(broker.market.current.index)
        # Remove by identity so an equal-comparing sibling order is never
        # dropped instead, and a missing order does not raise.
        for position, candidate in enumerate(broker.opening_orders):
            if candidate is order:
                del broker.opening_orders[position]
                break
        broker.filled_orders.append(order)
        logger.info(
            "[{}] {} {} | SIZE: {} | PRICE: {:.4f} | EQUITY: {:.4f}",
            broker.market.current.date,
            "BUY" if order.is_long else "SEL",
            order.tag,
            abs(order.size),
            order.price,
            broker.equity,
        )
        return True
class Broker:
    """Holds the account (cash plus per-tag holdings) and the order books.

    The account is a ``defaultdict(float)`` keyed by tag, with the cash
    balance stored under ``cash_name``.
    """

    cash_name = "cash"

    def __init__(
        self,
        market: Market,
        init_cash: float,
        account: T.Optional[defaultdict[str, float]] = None,
        order_helper: T.Optional[OrderHelper] = None,
    ):
        """Create a broker on ``market`` starting with ``init_cash``.

        ``account`` optionally seeds the initial holdings; it is copied, so
        the caller's mapping is never modified. ``order_helper`` defaults to
        a fresh ``OrderHelper``.
        """
        self.market = market
        # Copy into our own defaultdict: the cash entry written below must not
        # leak into the caller's mapping, and a plain dict must still support
        # ``account[tag] += size`` for tags it does not contain yet.
        self.init_account: defaultdict[str, float] = defaultdict(float, account or {})
        self.init_account[self.cash_name] = init_cash
        self.account: defaultdict[str, float] = self.init_account.copy()
        self.order_helper = order_helper or OrderHelper()
        self.opening_orders: list[Order] = []
        self.cancelled_orders: list[Order] = []
        self.filled_orders: list[Order] = []

    @property
    def orders(self) -> list[Order]:
        """All orders known to the broker: opening, then cancelled, then filled."""
        return self.opening_orders + self.cancelled_orders + self.filled_orders

    @property
    def equity(self) -> float:
        """Cash plus every non-zero holding valued at its current close."""
        current = self.market.current
        # Zero holdings are skipped: they add nothing, and ``0 * nan`` would
        # otherwise turn the whole sum into NaN.
        return self.cash + math.fsum(
            asset * current.get(tag, "close") for tag, asset in self.assets.items() if asset
        )

    @property
    def cash(self) -> float:
        """Current cash balance."""
        return self.account[self.cash_name]

    @property
    def assets(self) -> dict[str, float]:
        """Holdings per tag, i.e. the account without the cash entry."""
        return {tag: asset for tag, asset in self.account.items() if tag != self.cash_name}

    def trade(self):
        """Try to execute every opening order whose ``start`` has been reached.

        An order that cannot be executed and is still open is cancelled when
        its ``time_in_force`` is "GTC".
        """
        # Iterate over a snapshot: filling or cancelling edits the live list.
        for order in list(self.opening_orders):
            if order.start > self.market.current.index:
                continue
            if self.order_helper.trade_an_order(order, self):
                continue
            # NOTE(review): "GTC" conventionally means good-till-cancelled, yet
            # an unfilled GTC order is cancelled here on its first failed
            # attempt. Kept as is because flipping it changes every backtest -
            # confirm the intent.
            if order.status == "open" and order.time_in_force == "GTC":
                self.cancel_order(order)

    def cancel_order(self, order: Order):
        """Cancel ``order`` if it is still open; otherwise do nothing."""
        if order.status not in ["open"]:
            return
        # Remove by identity so that another order comparing equal to this one
        # is never dropped in its place.
        for position, candidate in enumerate(self.opening_orders):
            if candidate is order:
                del self.opening_orders[position]
                break
        order.cancel(self.market.current.index)
        self.cancelled_orders.append(order)

    def buy(self, order: Order) -> T.Optional[Order]:
        """Queue ``order`` as a buy, filling blank fields from the current state."""
        self.order_helper.write_blank_order(order, self, is_long=True)
        self.opening_orders.append(order)
        return order

    def sell(self, order: Order) -> T.Optional[Order]:
        """Queue ``order`` as a sell, filling blank fields from the current state."""
        self.order_helper.write_blank_order(order, self, is_long=False)
        self.opening_orders.append(order)
        return order

    def setup(self, backtest: Backtest):
        """Backtest hook: restore the initial account and empty the order books."""
        self.account = self.init_account.copy()
        # Orders of a previous run must not be traded or reported again.
        self.opening_orders = []
        self.cancelled_orders = []
        self.filled_orders = []

    def teardown(self, backtest: Backtest):
        """Backtest hook: nothing to release."""
        pass

    def __repr__(self):
        return f"{self.__class__.__name__}(cash={self.cash}, assets={self.assets}, equity={self.equity})"
class ParamRange(T.NamedTuple):
    """Search interval ``[low, high]`` for one strategy parameter.

    ``Backtest.optimize`` samples an integer when both bounds are ``int`` and
    a float otherwise.
    """

    low: float | int
    high: float | int
class Strategy(ABC):
    """Base class for trading strategies.

    Subclasses implement ``on_point`` and may declare ``trial_params``, a
    mapping from attribute name to the ``ParamRange`` that
    ``Backtest.optimize`` searches for that attribute.
    """

    trial_params: dict[str, ParamRange] = {}

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # Give each subclass its own (shallow) copy so that editing one
        # class's ranges never changes its parent's or siblings' dict.
        cls.trial_params = cls.trial_params.copy()

    def setup(self, backtest: Backtest):
        """Backtest hook called before a run; no-op by default."""
        pass

    def teardown(self, backtest: Backtest):
        """Backtest hook called after a run; no-op by default."""
        pass

    @abstractmethod
    def on_point(self, broker: Broker, view: MarketView) -> T.Optional[dict[str, float]]:
        """Handle one market point.

        Called by ``Backtest.run`` once per iterated view, before the broker
        trades. May return a mapping of indicator name to value for the
        backtest to record, or nothing.
        """
        raise NotImplementedError
class Backtest:
    """Runs a strategy over a market through a broker and records the results.

    ``logbook`` holds one list per key ("equity" and every market tag) with
    one entry per iterated point. ``indicators`` holds one list per indicator
    name returned by the strategy, padded with 0.0 where no value was given.
    """

    def __init__(self, *, market: Market, broker: Broker, strategy: Strategy) -> None:
        self.market = market
        self.broker = broker
        self.strategy = strategy
        self.indicators = defaultdict(list)
        self.logbook = defaultdict(list)

    def setup(self):
        """Clear recorded results and call ``setup`` on market, broker and strategy."""
        logger.debug("[START] Setting up environment")
        self.logbook.clear()
        # Indicators are per-run data as well; without this every further run
        # (e.g. each optimisation trial) would append to the previous series.
        self.indicators.clear()
        self.market.setup(self)
        self.broker.setup(self)
        self.strategy.setup(self)

    def teardown(self):
        """Call ``teardown`` on market, broker and strategy."""
        logger.debug("[END] Tearing down environment")
        self.market.teardown(self)
        self.broker.teardown(self)
        self.strategy.teardown(self)

    @np.errstate(invalid="ignore")
    def run(self):
        """Run one full backtest and return the broker's final equity.

        For every view yielded by the market the strategy is consulted, the
        broker trades, and equity, holdings and indicators are recorded.
        """
        self.setup()
        logger.info(
            "[START] Backtest strategy {} from {} to {}",
            self.strategy.__class__.__name__,
            self.market.first.date,
            self.market.last.date,
        )
        logger.info("[START] Equity: {}", self.broker.equity)
        # Monotonic clock: only used to space progress messages 30 s apart.
        last_update = time.monotonic()
        for view in self.market:
            now = time.monotonic()
            if now - last_update > 30:
                logger.info(
                    "[{}] Finished: {:.2f}%",
                    self.market.current.date,
                    self.market.current.index / len(self.market) * 100,
                )
                last_update = now
            indicators = self.strategy.on_point(self.broker, view)
            self.broker.trade()
            self.logbook["equity"].append(self.broker.equity)
            for tag in self.market.tags:
                # ``.get`` reads the holding without inserting a zero entry.
                self.logbook[tag].append(self.broker.account.get(tag, 0))
            if not indicators:
                continue
            logger.debug("[{}] Indicators: {}", self.market.current.date, indicators)
            index = self.market.current.index
            for name, value in indicators.items():
                series = self.indicators[name]
                # Pad up to the current index so the value lands at position
                # ``index`` even if this indicator was missing at some
                # earlier points (not only before its first appearance).
                if len(series) < index:
                    series.extend([0.0] * (index - len(series)))
                series.append(value)
        logger.info("[END] Equity: {}", self.broker.equity)
        self.teardown()
        return self.broker.equity

    @property
    def equities(self):
        """The equity recorded after each iterated point of the last run."""
        return self.logbook["equity"]

    def optimize(self, n_trials: int = 100, random_search: bool = False) -> optuna.Study:
        """Search ``strategy.trial_params`` with Optuna, maximising the return.

        Each trial sets the sampled values as attributes on the strategy, runs
        the backtest with this module's logging disabled, and scores it as
        ``final equity / first recorded equity - 1``.

        Args:
            n_trials: number of trials to run.
            random_search: use ``RandomSampler`` instead of Optuna's default sampler.

        Returns:
            The finished ``optuna.Study``.
        """
        trial_params = self.strategy.trial_params

        def objective(trial: optuna.Trial) -> float:
            # set trial params
            for name, param_range in trial_params.items():
                if isinstance(param_range.low, int) and isinstance(param_range.high, int):
                    value = trial.suggest_int(name, param_range.low, param_range.high)
                else:
                    value = trial.suggest_float(name, param_range.low, param_range.high)
                setattr(self.strategy, name, value)

            # skip duplicate trials
            # This has to happen after sampling: ``trial.params`` is empty
            # until the suggest calls above have run. ``deepcopy=False``
            # avoids copying every finished trial on each call.
            completed = trial.study.get_trials(
                deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)
            )
            for past in completed:
                if past.params == trial.params and past.value is not None:
                    return past.value

            logger.disable(__name__)
            try:
                equity = self.run()
            finally:
                # Re-enable logging even when the run raises.
                logger.enable(__name__)
            equities = self.logbook["equity"]
            if not equities:
                # Nothing was iterated, so there is no baseline to compare to.
                return 0.0
            roi = equity / equities[0] - 1
            return roi

        study = optuna.create_study(
            direction="maximize",
            sampler=optuna.samplers.RandomSampler() if random_search else None,
        )
        study.optimize(objective, n_trials=n_trials)
        return study

    def __repr__(self):
        return f"{self.__class__.__name__}(market={self.market}, broker={self.broker}, strategy={self.strategy})"
Author
Jiuh-star
commented
Aug 31, 2023
- This backtest framework makes it possible to backtest a portfolio of more than 500 stocks while staying fast (<500 ms on MACD).
- It lacks a data loader, analysis tools and technical indicators, but these should be easy to plug in, since there are good third-party alternatives (yfinance, quantstats, ta-lib-python, empyrical).
- It's recommended to use quantstats and mplfinance to analyze and visualize.
- Why not backtrader? backtrader is good, but metaprogramming in Python is bad.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment