Skip to content

Instantly share code, notes, and snippets.

@Jiuh-star
Created August 31, 2023 01:02
Show Gist options
  • Save Jiuh-star/2076fa0fc94e4bf4352357c200b2faa0 to your computer and use it in GitHub Desktop.
Save Jiuh-star/2076fa0fc94e4bf4352357c200b2faa0 to your computer and use it in GitHub Desktop.
Mini backtest framework
"""
Requirements:
python~=3.10
numpy
optuna
loguru
"""
from __future__ import annotations
import datetime
import math
import time
import typing as T
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, fields
import numpy as np
import optuna
import typing_extensions as TE
from loguru import logger
from numpy.typing import NDArray
# Type variable for the table type a Market holds; bound to Table so that
# Market can be parameterised with Table or any subclass of it.
_T = T.TypeVar("_T", bound="Table", covariant=True)
# Names exported by ``from <module> import *``.
__all__ = [
"Table",
"Market",
"MarketView",
"Order",
"OrderHelper",
"Broker",
"ParamRange",
"Strategy",
"Backtest",
]
@dataclass
class Table:
    """Column-oriented data for one instrument, identified by ``tag``.

    Every dataclass field except ``tag`` is treated as a column (see
    :attr:`columns`), and every column must have the same length as ``date``.

    Raises:
        ValueError: If any column's length differs from ``len(date)``.
    """

    tag: str
    date: NDArray[np.datetime64]
    close: NDArray[np.float64]

    def __post_init__(self) -> None:
        self._len = len(self.date)
        # Validate with an explicit exception rather than ``assert``: asserts
        # are stripped under ``python -O`` and ragged columns would slip through.
        ragged = {
            name: len(getattr(self, name))
            for name in self.columns
            if len(getattr(self, name)) != self._len
        }
        if ragged:
            raise ValueError(
                f"Table {self.tag!r}: expected every column to have length "
                f"{self._len}, got {ragged}"
            )

    @property
    def columns(self) -> list[str]:
        """Names of all dataclass fields except ``tag``, in declaration order."""
        return [f.name for f in fields(self) if f.name != "tag"]

    def __len__(self) -> int:
        """Number of rows, i.e. the length of ``date`` at construction time."""
        return self._len
class Market(T.Generic[_T]):
    """A set of tables sharing one date index, plus a cursor over that index.

    ``Market`` is its own iterator: ``next()`` advances :attr:`current` by one
    index and returns the new view.  Because the cursor is advanced *before*
    the view is returned, iteration after :meth:`reset` yields the views at
    indices ``1 .. len(market) - 1``; the view at index 0 is only reachable
    through :attr:`first`.

    Item access (``market[tag]``) reads, writes and deletes entries of
    :attr:`tag_table`.
    """

    tag_table: dict[str, _T]

    def __init__(self, tables: T.Iterable[_T]) -> None:
        """Build a market from tables that all share the same ``date`` array.

        Args:
            tables: Any iterable of tables, including one-shot iterators.

        Raises:
            ValueError: If ``tables`` is empty or a table's ``date`` column
                differs from that of the first table.
        """
        # Materialise once: the input is walked several times below, which
        # would silently drop data if a generator were passed in.
        tables = list(tables)
        if not tables:
            raise ValueError("Market requires at least one table")

        reference = tables[0]
        dates = reference.date
        for table in tables:
            # array_equal also copes with arrays of different length, where an
            # element-wise ``==`` cannot be broadcast.
            if not np.array_equal(dates, table.date):
                raise ValueError(
                    f"Table {table.tag!r} does not share the date index of "
                    f"table {reference.tag!r}"
                )

        self.tag_table = {table.tag: table for table in tables}
        # NOTE(review): ``tolist()`` is assumed to yield ``datetime.date``
        # objects (day-resolution datetime64) -- confirm against the data loader.
        self.dates: list[datetime.date] = T.cast(list[datetime.date], dates.tolist())
        self.tags = list(self.tag_table.keys())
        self.first = MarketView(market=self, index=0)
        self.last = MarketView(market=self, index=len(self) - 1)
        self.current = self.first

    @property
    def is_end(self) -> bool:
        """True once the cursor sits on (or past) the last index."""
        return self.current.index >= len(self) - 1

    def reset(self) -> None:
        """Move the cursor back to the first view."""
        self.current = self.first

    def setup(self, backtest: Backtest) -> None:
        """Prepare for a backtest run by rewinding the cursor."""
        self.reset()

    def teardown(self, backtest: Backtest) -> None:
        """Hook called after a backtest run; does nothing here."""

    def __getitem__(self, index):
        return self.tag_table[index]

    def __setitem__(self, index, value):
        self.tag_table[index] = value

    def __delitem__(self, index):
        del self.tag_table[index]

    def __iter__(self):
        # The market is its own iterator, so iteration resumes from ``current``.
        return self

    def __next__(self):
        if self.is_end:
            raise StopIteration
        self.current = MarketView(market=self, index=self.current.index + 1)
        return self.current

    def __len__(self) -> int:
        """Number of dates in the shared index."""
        return len(self.dates)

    def __repr__(self) -> str:
        repr_params = {
            "scale": len(self.tags),
            "start_date": self.dates[0],
            "stop_date": self.dates[-1],
            "duration": len(self),
        }
        return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in repr_params.items())})"
class MarketView:
    """The market as seen at one index of its date axis.

    Columns can be read per tag with :meth:`get`, for all tags with
    :meth:`get_all`, or through attribute access: ``view.close`` is
    ``view.get_all("close")`` and ``view.historical_close`` is
    ``view.get_all("close", history=True)``.

    ``history=True`` returns the slice ``[:index]``, i.e. the rows strictly
    before this view's own row.
    """

    historical_prefix = "historical_"

    def __init__(self, market: Market, index: int):
        self.market = market
        self.index = index
        # Keyed by (column name, history flag): the two flavours of the same
        # column are different results and must not share a cache slot.
        self._cache: dict[tuple[str, bool], dict] = {}

    @T.overload
    def get(self, tag: str, name: T.Literal["date"]) -> np.datetime64:
        ...

    @T.overload
    def get(self, tag: str, name: str) -> float:
        ...

    @T.overload
    def get(self, tag: str, name: str, *, history: T.Literal[False]) -> float:
        ...

    @T.overload
    def get(self, tag: str, name: str, *, history: T.Literal[True]) -> NDArray[np.float64]:
        ...

    @T.overload
    def get(self, tag: str, name: str, *, history: bool = False) -> float | NDArray[np.float64]:
        ...

    def get(self, tag: str, name: str, *, history: bool = False):
        """Read column ``name`` of table ``tag``.

        Returns the value at this view's index, or -- with ``history=True`` --
        the slice of all rows before it.
        """
        line: NDArray = getattr(self.market[tag], name)
        if history:
            return line[: self.index]
        return line[self.index]

    @T.overload
    def get_all(self, name: T.Literal["date"]) -> dict[str, np.datetime64]:
        ...

    @T.overload
    def get_all(self, name: str) -> dict[str, float]:
        ...

    @T.overload
    def get_all(self, name: str, *, history: T.Literal[True]) -> dict[str, NDArray[np.float64]]:
        ...

    @T.overload
    def get_all(self, name: str, *, history: T.Literal[False]) -> dict[str, float]:
        ...

    @T.overload
    def get_all(self, name: str, *, history: bool = False) -> dict[str, float | NDArray[np.float64]]:
        ...

    def get_all(self, name: str, *, history: bool = False):
        """Return ``{tag: get(tag, name, history=history)}`` for every market tag.

        Results are cached per view; the same dict object is returned on
        repeated calls with the same arguments.
        """
        key = (name, history)
        if key not in self._cache:
            self._cache[key] = {tag: self.get(tag, name, history=history) for tag in self.market.tags}
        return self._cache[key]

    @property
    def date(self) -> datetime.date:
        """The market date at this view's index."""
        return self.market.dates[self.index]

    def __getattr__(self, name: str):
        # Only reached when normal lookup fails.  Private and dunder names are
        # never columns; refusing them stops protocol probes (copy, pickle,
        # hasattr) from recursing through ``self._cache`` on an instance whose
        # ``__init__`` has not run.
        if name.startswith("_"):
            raise AttributeError(name)
        if name.startswith(self.historical_prefix):
            # Strip the prefix so ``historical_close`` resolves to ``close``.
            return self.get_all(name.removeprefix(self.historical_prefix), history=True)
        return self.get_all(name)

    @property
    def is_end(self) -> bool:
        """True when this view is at (or past) the market's last index."""
        return self.index >= len(self.market) - 1
@dataclass
class Order:
    """A request to trade ``size`` units of ``tag`` at ``price``.

    A positive ``size`` is a buy, a negative ``size`` a sell.  ``start`` and
    ``stop`` are market indices: ``stop`` is written by :meth:`fill` and
    :meth:`cancel`.  Each order receives a sequential ``id`` and is created
    with ``order_type == "limit"`` and ``time_in_force == "GTC"``.
    """

    tag: str
    size: int = 0
    price: float = 0.0
    start: int = 0
    stop: int = 0
    commission: float = 0.0
    status: T.Literal["open", "filled", "cancelled"] = "open"
    # Next id to hand out.
    _iota: T.ClassVar[int] = 0
    # Every order ever constructed is appended here and nothing in this
    # module removes entries.
    # NOTE(review): this keeps all orders alive for the process lifetime.
    _orders: T.ClassVar[list[Order]] = []

    def __post_init__(self) -> None:
        cls = self.__class__
        self.id = self._iota
        cls._iota += 1
        self.order_type = "limit"
        self.time_in_force = "GTC"
        cls._orders.append(self)

    @property
    def is_long(self) -> bool:
        """True for a buy order (``size > 0``); always a real ``bool``."""
        return bool(self.size and self.size > 0)

    @property
    def is_short(self) -> bool:
        """True for a sell order (``size < 0``); always a real ``bool``."""
        return bool(self.size and self.size < 0)

    def fill(self, index: int) -> None:
        """Mark the order as filled at market index ``index``."""
        self.stop = index
        self.status = "filled"

    def cancel(self, index: int) -> None:
        """Mark the order as cancelled at market index ``index``."""
        self.stop = index
        self.status = "cancelled"

    def __hash__(self) -> int:
        # NOTE(review): hashing is by id while the dataclass-generated __eq__
        # compares field values, so equal orders may hash differently.
        return self.id
class OrderHelper:
    """Stateless rules for validating, completing and executing orders.

    All methods are classmethods taking the order and the broker it is being
    processed by.
    """

    @classmethod
    def is_valid(cls, order: Order, broker: Broker) -> bool:
        """Return True for an open order on a known tag with non-zero size
        and a price that is not ``<= 0``."""
        if order.status != "open" or order.tag not in broker.market.tags:
            return False
        return not (order.size == 0 or order.price <= 0)

    @classmethod
    def is_tradable(cls, order: Order, broker: Broker, *, skip_valid: bool = False) -> bool:
        """Return True when the order can execute against the current close.

        A buy needs a price at or above the close and a size the broker's cash
        can cover after commission; a sell needs a price at or below the close
        and a size not exceeding the position held.
        """
        if not (skip_valid or cls.is_valid(order, broker)):
            return False

        close = broker.market.current.get(order.tag, "close")
        # The second operand of each ``or`` is only evaluated when the price
        # test passes, exactly as in a chain of early returns.
        if order.is_long:
            rejected = (
                order.price < close
                or order.size > broker.cash * (1 - order.commission) / order.price
            )
        else:
            rejected = (
                order.price > close
                or abs(order.size) > broker.account[order.tag]
            )
        return not rejected

    @classmethod
    def write_blank_order(cls, order: Order, broker: Broker, is_long: bool):
        """Fill in fields still at their zero default.

        Price defaults to the current close, size to everything affordable
        (buy) or the whole position (sell), start to the current index.
        """
        if order.price == 0.0:
            order.price = broker.market.current.get(order.tag, "close")

        if order.size == 0:
            order.size = (
                int(broker.cash * (1 - order.commission) / order.price)
                if is_long
                else -int(broker.account[order.tag])
            )

        if order.start == 0:
            order.start = broker.market.current.index

    @classmethod
    def trade_an_order(cls, order: Order, broker: Broker, *, skip_check: bool = False) -> bool:
        """Execute the order against the broker's account.

        Returns False without side effects when the order is not tradable
        (unless ``skip_check`` is set); otherwise updates cash and position,
        moves the order from the open to the filled list, logs it and
        returns True.
        """
        if not (skip_check or cls.is_tradable(order, broker)):
            return False

        # Signed notional: positive for buys, negative for sells.
        notional: float = order.size * order.price
        fee = abs(notional * order.commission)
        broker.account[broker.cash_name] -= notional + fee
        broker.account[order.tag] += order.size

        order.fill(broker.market.current.index)
        broker.opening_orders.remove(order)
        broker.filled_orders.append(order)

        logger.info(
            "[{}] {} {} | SIZE: {} | PRICE: {:.4f} | EQUITY: {:.4f}",
            broker.market.current.date,
            "BUY" if order.is_long else "SEL",
            order.tag,
            abs(order.size),
            order.price,
            broker.equity,
        )
        return True
class Broker:
    """Holds the account (cash plus positions) and the order book for a market.

    ``account`` maps each tag to the quantity held; the cash balance lives in
    the same mapping under :attr:`cash_name`.
    """

    cash_name = "cash"

    def __init__(
        self,
        market: Market,
        init_cash: float,
        account: T.Optional[defaultdict[str, float]] = None,
        order_helper: T.Optional[OrderHelper] = None,
    ):
        """
        Args:
            market: The market whose current view prices the positions.
            init_cash: Starting cash balance.
            account: Optional starting positions (tag -> quantity).  It is
                copied, so the caller's mapping is left untouched.
            order_helper: Order rules to use; defaults to ``OrderHelper()``.
        """
        self.market = market
        # Copy into a fresh defaultdict: writing the cash entry below must not
        # mutate the caller's mapping, and a plain dict must still support the
        # ``account[tag] += size`` updates done when orders fill.
        self.init_account: defaultdict[str, float] = defaultdict(float, account or {})
        self.init_account[self.cash_name] = init_cash
        self.account: defaultdict[str, float] = self.init_account.copy()
        self.order_helper = order_helper if order_helper is not None else OrderHelper()
        self.opening_orders: list[Order] = []
        self.cancelled_orders: list[Order] = []
        self.filled_orders: list[Order] = []

    @property
    def orders(self) -> list[Order]:
        """All known orders: open, then cancelled, then filled."""
        return self.opening_orders + self.cancelled_orders + self.filled_orders

    @property
    def equity(self) -> float:
        """Cash plus every position valued at the current close."""
        view = self.market.current
        return self.cash + math.fsum(
            quantity * view.get(tag, "close") for tag, quantity in self.assets.items()
        )

    @property
    def cash(self) -> float:
        """Current cash balance."""
        return self.account[self.cash_name]

    @property
    def assets(self) -> dict[str, float]:
        """The account without its cash entry."""
        return {tag: asset for tag, asset in self.account.items() if tag != self.cash_name}

    def trade(self) -> None:
        """Try to execute every open order whose ``start`` has been reached.

        Orders that do not execute are cancelled.
        """
        # Iterate over a snapshot: filling and cancelling mutate the live list.
        for order in list(self.opening_orders):
            if order.start > self.market.current.index:
                continue
            if self.order_helper.trade_an_order(order, self):
                continue
            # NOTE(review): this cancels an unfilled order when its
            # time_in_force is "GTC", which reads inverted relative to the
            # usual meaning of good-till-cancelled.  Behaviour is kept as-is
            # because every order is created as GTC -- confirm the intent.
            if order.status == "open" and order.time_in_force == "GTC":
                self.cancel_order(order)

    def cancel_order(self, order: Order) -> None:
        """Cancel an open order at the current index; no-op for other statuses."""
        if order.status != "open":
            return
        if order in self.opening_orders:
            self.opening_orders.remove(order)
        order.cancel(self.market.current.index)
        self.cancelled_orders.append(order)

    def _submit(self, order: Order, *, is_long: bool) -> Order:
        """Complete blank fields of ``order`` and queue it as open."""
        self.order_helper.write_blank_order(order, self, is_long=is_long)
        self.opening_orders.append(order)
        return order

    def buy(self, order: Order) -> T.Optional[Order]:
        """Queue ``order`` as a buy and return it."""
        return self._submit(order, is_long=True)

    def sell(self, order: Order) -> T.Optional[Order]:
        """Queue ``order`` as a sell and return it."""
        return self._submit(order, is_long=False)

    def setup(self, backtest: Backtest) -> None:
        """Restore the initial account and start with an empty order book.

        The order lists are replaced rather than cleared in place, so a list
        obtained from a previous run keeps its contents.
        """
        self.account = self.init_account.copy()
        # Without this, orders from an earlier run would leak into the next
        # one (e.g. across optimisation trials).
        self.opening_orders = []
        self.cancelled_orders = []
        self.filled_orders = []

    def teardown(self, backtest: Backtest) -> None:
        """Hook called after a backtest run; does nothing here."""

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(cash={self.cash}, assets={self.assets}, equity={self.equity})"
class ParamRange(T.NamedTuple):
    """Lower and upper bound of one tunable parameter.

    The two values are passed as ``low`` and ``high`` to Optuna's
    ``suggest_int`` (when both are ints) or ``suggest_float``.
    """

    low: int | float
    high: int | float
class Strategy(ABC):
    """Base class for trading strategies.

    Subclasses implement :meth:`on_point`, which is called once per market
    view, and may list tunable attributes in :attr:`trial_params`.
    """

    # Attribute name -> search range; read by Backtest.optimize.
    trial_params: dict[str, ParamRange] = {}

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # Give every subclass its own dict so that editing one class's ranges
        # never leaks into its parent or siblings.
        own_params = cls.trial_params.copy()
        cls.trial_params = own_params

    @abstractmethod
    def on_point(self, broker: Broker, view: MarketView) -> T.Optional[dict[str, float]]:
        """Handle one market view; may return indicator values keyed by name."""
        raise NotImplementedError

    def setup(self, backtest: Backtest):
        """Hook called before a backtest run; does nothing by default."""
        return None

    def teardown(self, backtest: Backtest):
        """Hook called after a backtest run; does nothing by default."""
        return None
class Backtest:
    """Runs a strategy over a market through a broker and records the results.

    Attributes:
        indicators: For each indicator name returned by
            ``Strategy.on_point``, the list of values collected during the
            latest run.
        logbook: ``"equity"`` plus one series per market tag (quantity held),
            each appended to once per processed view of the latest run.
    """

    def __init__(self, *, market: Market, broker: Broker, strategy: Strategy) -> None:
        self.market = market
        self.broker = broker
        self.strategy = strategy
        self.indicators = defaultdict(list)
        self.logbook = defaultdict(list)

    def setup(self):
        """Reset recorded results and let market, broker and strategy prepare."""
        logger.debug("[START] Setting up environment")
        self.logbook.clear()
        # Indicators are per-run data as well; leaving them in place would
        # append a new run's values onto the previous run's series.
        self.indicators.clear()
        self.market.setup(self)
        self.broker.setup(self)
        self.strategy.setup(self)

    def teardown(self):
        """Let market, broker and strategy clean up after a run."""
        logger.debug("[END] Tearing down environment")
        self.market.teardown(self)
        self.broker.teardown(self)
        self.strategy.teardown(self)

    @np.errstate(invalid="ignore")
    def run(self):
        """Run one full backtest and return the broker's final equity.

        For every view yielded by the market the strategy is consulted, the
        broker trades, and equity, positions and indicators are recorded.
        Progress is logged at most every 30 seconds.
        """
        self.setup()
        logger.info(
            "[START] Backtest strategy {} from {} to {}",
            self.strategy.__class__.__name__,
            self.market.first.date,
            self.market.last.date,
        )
        logger.info("[START] Equity: {}", self.broker.equity)

        last_update = time.time()
        for view in self.market:
            if time.time() - last_update > 30:
                logger.info(
                    "[{}] Finished: {:.2f}%",
                    self.market.current.date,
                    self.market.current.index / len(self.market) * 100,
                )
                last_update = time.time()

            indicators = self.strategy.on_point(self.broker, view)
            self.broker.trade()

            self.logbook["equity"].append(self.broker.equity)
            for tag in self.market.tags:
                # .get() reads without inserting a key into the defaultdict.
                self.logbook[tag].append(self.broker.account.get(tag, 0))

            if not indicators:
                continue
            logger.debug("[{}] Indicators: {}", self.market.current.date, indicators)
            for name, value in indicators.items():
                if name not in self.indicators:
                    # First sighting of this indicator: pad with zeros for the
                    # indices that came before the current one.
                    self.indicators[name] = [0.0] * self.market.current.index
                self.indicators[name].append(value)

        logger.info("[END] Equity: {}", self.broker.equity)
        self.teardown()
        return self.broker.equity

    @property
    def equities(self):
        """The equity series recorded by the latest run."""
        return self.logbook["equity"]

    def optimize(self, n_trials: int = 100, random_search: bool = False) -> optuna.Study:
        """Search ``strategy.trial_params`` with Optuna, maximising return.

        Each trial sets the sampled values as attributes on the strategy, runs
        the backtest with this module's logging disabled, and scores it as
        ``final_equity / first_recorded_equity - 1``.  A trial whose sampled
        parameters equal those of an already completed trial reuses that
        trial's value instead of running again.

        Args:
            n_trials: Number of trials to run.
            random_search: Use Optuna's ``RandomSampler`` instead of the
                default sampler.

        Returns:
            The finished ``optuna.Study``.
        """
        trial_params = self.strategy.trial_params

        def objective(trial: optuna.Trial) -> float:
            # Sample first: ``trial.params`` stays empty until suggest_* has
            # been called, so a duplicate check before this point can never
            # match a previous trial.
            for name, param_range in trial_params.items():
                if isinstance(param_range.low, int) and isinstance(param_range.high, int):
                    value = trial.suggest_int(name, param_range.low, param_range.high)
                else:
                    value = trial.suggest_float(name, param_range.low, param_range.high)
                setattr(self.strategy, name, value)

            # skip duplicate trials
            completed = trial.study.get_trials(
                deepcopy=False,
                states=(optuna.trial.TrialState.COMPLETE,),
            )
            for past in reversed(completed):
                # ``is not None`` rather than truthiness: 0.0 is a valid score.
                if past.params == trial.params and past.value is not None:
                    return past.value

            logger.disable(__name__)
            try:
                equity = self.run()
            finally:
                # Re-enable logging even when the run raises.
                logger.enable(__name__)

            equities = self.logbook["equity"]
            if not equities:
                # No view was processed, so nothing was traded.
                return 0.0
            return equity / equities[0] - 1

        study = optuna.create_study(
            direction="maximize",
            sampler=optuna.samplers.RandomSampler() if random_search else None,
        )
        study.optimize(objective, n_trials=n_trials)
        return study

    def __repr__(self):
        return f"{self.__class__.__name__}(market={self.market}, broker={self.broker}, strategy={self.strategy})"
@Jiuh-star
Copy link
Author

  • This backtest framework makes it possible to backtest a portfolio of more than 500 stocks while staying fast (<500 ms on MACD).
  • It lacks a data loader, analysis tools, and technical indicators, but these should be easy to add since there are better third-party alternatives (yfinance, quantstats, ta-lib-python, empyrical).
  • It's recommended to use quantstats and mplfinance to analyze and visualize.
  • Why not backtrader? backtrader is good, but meta programming in Python is bad.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment