Last active
April 25, 2020 12:06
-
-
Save yoonbae81/7e31630a7099ceb7d3d29d114283e036 to your computer and use it in GitHub Desktop.
Message-passing prototype for backtesting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from collections import defaultdict | |
from dataclasses import dataclass | |
from multiprocessing import Pipe, Process, Queue | |
from multiprocessing.connection import Connection, wait | |
from os import cpu_count | |
from random import randint | |
from threading import Thread | |
from time import sleep | |
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, TypeVar | |
@dataclass(init=True, repr=True, eq=True)
class Msg:
    """A single message exchanged between the pipeline nodes.

    ``type`` is the dispatch key each receiving node looks up in its handler
    table (e.g. 'TICK', 'SIGNAL', 'POSITION', 'EOF', 'EOD', 'RESET', 'QUIT').
    The other fields carry the payload and default to empty/zero values.
    Fields are positional in declaration order, so ``Msg('TICK', 's1')`` sets
    ``type`` and ``symbol``.
    """

    # Message kind; selects the handler on the receiving side.
    type: str = ''
    # Instrument identifier; Router pins each symbol to one Analyzer.
    symbol: str = ''
    market: str = ''
    price: float = 0
    # Position size; Analyzer compares it against 0 when tracking positions.
    quantity: float = 0
    # Signal strength; Analyzer fills it with randint(-10, 10).
    strength: int = 0
    timestamp: int = 0
class Fetcher(Thread):
    """Producer thread that emits synthetic TICK messages.

    It sends ``num_ticks`` TICK messages for symbols ``s1``, ``s2``, ...,
    pausing ``interval`` seconds before the first tick and after each one.
    After every ``eof_every``-th tick it also sends an EOF marker (a value
    <= 0 disables EOF markers). It always finishes with a single EOD message.

    The defaults reproduce the original hard-coded behaviour: 9 ticks, EOF
    every 30 ticks, 0.2 s pauses. With those defaults no EOF is emitted
    because 30 > 9; pass a smaller ``eof_every`` to exercise the EOF path.

    ``output`` is the write end of a pipe. It must be assigned (e.g. by
    ``Router.connect``) before the thread is started.
    """

    def __init__(self, *, num_ticks: int = 9, eof_every: int = 30,
                 interval: float = 0.2) -> None:
        super().__init__(name=self.__class__.__name__)
        self.output: Connection
        self.logger = logging.getLogger(self.__class__.__name__)
        self.num_ticks = num_ticks
        self.eof_every = eof_every
        self.interval = interval

    def run(self) -> None:
        """Send the TICK stream, optional EOF markers, then EOD."""
        sleep(self.interval)
        self.logger.debug('Running')
        for i in range(1, self.num_ticks + 1):
            msg = Msg('TICK', f's{i}')
            self.output.send(msg)
            print(f'Fetcher sent: {msg}')
            # Guard against eof_every <= 0, which would otherwise divide by zero.
            if self.eof_every > 0 and i % self.eof_every == 0:
                self.output.send(Msg('EOF'))
            sleep(self.interval)
        self.output.send(Msg('EOD'))
class Analyzer(Process):
    """Worker process that answers TICK messages with SIGNAL messages.

    Each instance reads messages from ``input`` and writes replies to
    ``output``. Both are pipe ends that must be assigned (e.g. by
    ``Router.connect``) before ``start()``. Instances are named Analyzer1,
    Analyzer2, ... from a class-level counter bumped in the constructor.
    """

    # Number of Analyzer instances constructed so far; used only for naming.
    count: int = 0

    def __init__(self) -> None:
        self.__class__.count += 1
        super().__init__(name=self.__class__.__name__ + str(self.__class__.count))
        self.input: Connection
        self.output: Connection

        # Event loop
        self._running: bool = False
        self._handlers: Dict[str, Callable[[Msg], None]] = {
            'TICK': self._handler_tick,
            'POSITION': self._handler_position,
            'RESET': self._handler_reset,
            'QUIT': self._handler_quit,
        }
        # Symbols this analyzer currently considers to hold an open position.
        self._positions: Set[str] = set()

    def run(self) -> None:
        """Dispatch incoming messages until a QUIT message is handled."""
        self._running = True
        while self._running:
            msg = self.input.recv()
            print(f'{self.name} received: {msg}')
            handler = self._handlers.get(msg.type)
            if handler is None:
                # Same policy as Router: report and skip unknown types rather
                # than letting a KeyError terminate the worker process.
                print('Unknown message ', msg)
                continue
            handler(msg)

    def _handler_tick(self, msg: Msg) -> None:
        """Turn a TICK into a SIGNAL in place and send it on ``output``.

        ``strength`` is drawn uniformly from [-10, 10] (placeholder logic).
        """
        msg.type = 'SIGNAL'
        msg.strength = randint(-10, 10)
        self.output.send(msg)

    def _handler_position(self, msg: Msg) -> None:
        """Record whether ``msg.symbol`` currently has an open position.

        A zero quantity closes the position and any other quantity opens it.
        Closing a symbol that is not tracked is a no-op. Previously it
        wrongly added the symbol to the open set.
        """
        if msg.quantity == 0:
            self._positions.discard(msg.symbol)
        else:
            self._positions.add(msg.symbol)

    def _handler_quit(self, msg: Msg) -> None:
        """Stop the event loop after the current message."""
        self._running = False

    def _handler_reset(self, msg: Msg) -> None:
        """Handle RESET (sent by Router on EOF); currently a no-op."""
        pass
class Broker(Thread):
    """Thread that turns SIGNAL messages into POSITION messages.

    ``input`` and ``output`` are pipe ends that must be assigned (e.g. by
    ``Router.connect``) before the thread is started.
    """

    def __init__(self) -> None:
        super().__init__(name=self.__class__.__name__)
        self.input: Connection
        self.output: Connection

        # Event loop
        self._running: bool = False
        self._handlers: Dict[str, Callable[[Msg], None]] = {
            'SIGNAL': self._handler_signal,
            'QUIT': self._handler_quit,
        }

    def run(self) -> None:
        """Dispatch incoming messages until a QUIT message is handled."""
        self._running = True
        while self._running:
            msg = self.input.recv()
            print(f'{self.name} received: {msg}')
            handler = self._handlers.get(msg.type)
            if handler is None:
                # Same policy as Router: report and skip unknown types rather
                # than letting a KeyError kill the thread, which would leave
                # the sender writing into a pipe nobody reads.
                print('Unknown message ', msg)
                continue
            handler(msg)

    def _handler_signal(self, msg: Msg) -> None:
        """Re-label a SIGNAL as POSITION in place and send it on ``output``.

        NOTE(review): no order/fill logic yet; ``quantity`` is forwarded as-is.
        """
        msg.type = 'POSITION'
        self.output.send(msg)

    def _handler_quit(self, msg: Msg) -> None:
        """Stop the event loop after the current message."""
        self._running = False
# The node kinds Router.connect() knows how to wire up.
Node = TypeVar('Node', Fetcher, Analyzer, Broker)


class Router(Thread):
    """Central hub that wires nodes together with pipes and routes messages.

    Routing rules (see ``_handlers``):
      * TICK, POSITION -> the Analyzer pinned to ``msg.symbol``
      * SIGNAL         -> the Broker
      * EOF            -> RESET to every Analyzer
      * EOD            -> QUIT to every Analyzer and the Broker, then stop
    """

    def __init__(self) -> None:
        super().__init__(name=self.__class__.__name__)

        # Event loop
        self._running = True

        # Fetcher
        self._from_fetcher: Connection

        # Broker
        self._from_broker: Connection
        self._to_broker: Connection

        # Analyzer
        self._from_analyzers: List[Connection] = []
        self._to_analyzers: List[Connection] = []
        # Number of symbols assigned to each analyzer (keyed by its pipe).
        self._analyzer_counter: Dict[Connection, int] = {}
        # Symbol -> pipe of the analyzer that owns it.
        self._analyzer_assigned: Dict[str, Connection] = {}

        self._handlers: Dict[str, Callable[[Msg], None]] = {
            'TICK': self._handler_tick,
            'SIGNAL': self._handler_signal,
            'POSITION': self._handler_position,
            'EOF': self._handler_eof,
            'EOD': self._handler_eod,
        }
        # Count of received messages per type, printed when the loop ends.
        self._msg_counter: DefaultDict[str, int] = defaultdict(int)

    def connect(self, node: Node) -> bool:
        """Create one-way pipes between this router and ``node``.

        Call before starting either side. Analyzers and the Broker get an
        inbound and an outbound pipe; the Fetcher only sends, so it gets
        just an outbound pipe.

        Returns:
            True once the node is connected.

        Raises:
            TypeError: if ``node`` is not a Fetcher, Analyzer or Broker.
        """
        # Pipe(duplex=False) returns (receive-only end, send-only end).
        if isinstance(node, Analyzer):
            from_analyzer, node.output = Pipe(duplex=False)
            self._from_analyzers.append(from_analyzer)
            node.input, to_analyzer = Pipe(duplex=False)
            self._to_analyzers.append(to_analyzer)
            self._analyzer_counter[to_analyzer] = 0
        elif isinstance(node, Fetcher):
            # Fetcher does not need to listen from Router
            self._from_fetcher, node.output = Pipe(duplex=False)
        elif isinstance(node, Broker):
            self._from_broker, node.output = Pipe(duplex=False)
            node.input, self._to_broker = Pipe(duplex=False)
        else:
            raise TypeError(node)

        print(f'{node.name} connected')
        return True

    def run(self) -> None:
        """Route messages until EOD is handled, then print per-type counts.

        Requires the Fetcher and Broker to have been connected.
        """
        while self._running:
            ready = wait([*self._from_analyzers,
                          self._from_broker,
                          self._from_fetcher],
                         timeout=1)
            for conn in ready:
                msg = conn.recv()
                self._msg_counter[msg.type] += 1
                handler = self._handlers.get(msg.type)
                if handler is None:
                    print('Unknown message ', msg)
                    continue
                # Called outside any try/except so that an error raised inside
                # a handler is not misreported as an unknown message type.
                handler(msg)

        print(self._msg_counter)

    def _get_analyzer(self, symbol: str) -> Connection:
        """Return the pipe to the Analyzer responsible for ``symbol``.

        A symbol seen for the first time is pinned to the analyzer with the
        fewest symbols assigned so far. All later messages for that symbol
        then go to the same worker.

        Raises:
            ValueError: if no Analyzer has been connected.
        """
        to_analyzer = self._analyzer_assigned.get(symbol)
        if to_analyzer is not None:
            return to_analyzer

        if not self._analyzer_counter:
            raise ValueError(f'No analyzer connected to handle symbol {symbol!r}')

        to_analyzer = min(self._analyzer_counter,
                          key=self._analyzer_counter.get)
        self._analyzer_assigned[symbol] = to_analyzer
        self._analyzer_counter[to_analyzer] += 1
        return to_analyzer

    def _handler_tick(self, msg: Msg) -> None:
        """Forward a TICK to the analyzer owning its symbol."""
        self._get_analyzer(msg.symbol).send(msg)

    def _handler_signal(self, msg: Msg) -> None:
        """Forward a SIGNAL to the Broker."""
        self._to_broker.send(msg)

    def _handler_position(self, msg: Msg) -> None:
        """Forward a POSITION update to the analyzer owning its symbol."""
        self._get_analyzer(msg.symbol).send(msg)

    def _handler_eof(self, msg: Msg) -> None:
        """Broadcast RESET to every analyzer."""
        for to_analyzer in self._to_analyzers:
            to_analyzer.send(Msg('RESET'))

    def _handler_eod(self, msg: Msg) -> None:
        """Broadcast QUIT to all analyzers and the Broker, then stop routing.

        NOTE(review): after this the loop stops calling recv(). Any SIGNAL
        or POSITION still in flight from an analyzer or the Broker is never
        read or forwarded; a drain phase would be needed to guarantee
        delivery.
        """
        for node in [*self._to_analyzers, self._to_broker]:
            node.send(Msg('QUIT'))
        self._running = False
if __name__ == '__main__':
    fetcher = Fetcher()
    # One Analyzer fewer than the reported CPU count, but never zero.
    # With zero analyzers, Router._get_analyzer has nothing to route the
    # first TICK to: the Router thread dies and the Broker then blocks on
    # recv() forever, so join() never returns. On a single-CPU machine the
    # old expression `(cpu_count() or 2) - 1` produced exactly that.
    analyzers = [Analyzer() for _ in range(max(1, (cpu_count() or 2) - 1))]
    broker = Broker()
    router = Router()

    nodes: List[Any] = [broker, *analyzers, fetcher]
    for node in nodes:
        router.connect(node)

    # Start the router first, then the nodes.
    nodes.insert(0, router)
    for node in nodes:
        node.start()

    # Join in reverse start order; the router is joined last.
    for node in reversed(nodes):
        node.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment