Skip to content

Instantly share code, notes, and snippets.

@jwiegley
Created September 20, 2024 20:10
Show Gist options
  • Save jwiegley/ffdcf3a40158d79b38de807e318b743b to your computer and use it in GitHub Desktop.
Save jwiegley/ffdcf3a40158d79b38de807e318b743b to your computer and use it in GitHub Desktop.
# This Python code translates the Haskell code into idiomatic Python, using
# features like dataclasses, pattern matching (using the =match= statement
# introduced in Python 3.10), and type hints.
#
# The main changes made in the Python version are:
#
# 1. Using dataclasses instead of Haskell's record syntax for defining data
# types.
# 2. Implementing algebraic data types using Python's class inheritance and
# the =|= operator for union types.
# 3. Using Python's =match= statement for pattern matching, which is similar
# to Haskell's case expressions.
# 4. Defining the helper functions =break_= and =survey= and the =MkZipper=
# dataclass to mimic Haskell's =break= and the zipper type used by the Haskell code.
#
# Note that this code requires Python 3.10 or later due to the use of the
# =match= statement for pattern matching. If you are using an earlier version
# of Python, you may need to adjust the code accordingly.
from __future__ import annotations

import functools
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar
@dataclass
class TimePrice:
price: float
time: datetime
@dataclass
class Lot:
lot_amount: float
lot_detail: TimePrice
@dataclass
class Open:
open_lot: Lot
open_basis: Optional[float] = None
@dataclass
class Closed:
closing_lot: Lot
closing_detail: TimePrice
closing_washable: bool
# A position is either still held (Open) or already closed out (Closed).
Position = Open | Closed
# A strategy rewrites a list of positions. add_to_positions applies it both
# before and after folding in a new lot (presumably to impose a lot order such
# as FIFO/LIFO -- TODO confirm).
Strategy = Callable[[List[Position]], List[Position]]
class LotChange:
pass
@dataclass
class NoChange(LotChange):
pass
@dataclass
class AddLot(LotChange):
lot: Lot
@dataclass
class ReduceLot(LotChange):
lot: Lot
@dataclass
class ReplaceLot(LotChange):
lot: Lot
def add_lot(x: Lot, y: Lot) -> LotChange:
    """Combine an existing lot ``x`` with an incoming lot ``y``.

    Returns:
        ReduceLot(Lot(x + y, x's detail)) if the amounts have opposite signs
            and |x| >= |y| (y partially or exactly offsets x);
        ReplaceLot(Lot(x + y, y's detail)) if the signs are opposite and
            |x| < |y| (y offsets all of x with quantity left over);
        AddLot(Lot(x + y, x's detail)) if the signs are not opposite and the
            two lot details (price and time) are equal;
        NoChange() otherwise.
    """
    x_amt, x_info = x.lot_amount, x.lot_detail
    y_amt, y_info = y.lot_amount, y.lot_detail
    combined = x_amt + y_amt

    # Strictly opposite signs; zero (or NaN) amounts never count as opposite.
    opposite = (x_amt < 0 < y_amt) or (y_amt < 0 < x_amt)
    if not opposite:
        return AddLot(Lot(combined, x_info)) if x_info == y_info else NoChange()
    if abs(x_amt) < abs(y_amt):
        return ReplaceLot(Lot(combined, y_info))
    return ReduceLot(Lot(combined, x_info))
def add_to_positions(strategy: Strategy, x: Lot, positions: List[Position]) -> List[Position]:
def go(y: Lot, ps: List[Position]) -> List[Position]:
if not ps:
return [Open(y)]
match ps[0]:
case Open(z, wash):
match add_lot(z, y):
case NoChange():
return [Open(z, wash)] + go(y, ps[1:])
case AddLot(w):
return [Open(w, wash)] + ps[1:]
case ReduceLot(Lot(zn_prime, zd_prime)):
zn = z.lot_amount
zd_prime = z.lot_detail
yd = y.lot_detail
return [Open(Lot(zn_prime, zd_prime), wash), Closed(Lot(zn - zn_prime, zd_prime), yd, True)] + ps[1:]
case ReplaceLot(w):
yd = y.lot_detail
return [Closed(z, yd, True)] + go(w, ps[1:])
case _:
return [ps[0]] + go(y, ps[1:])
return strategy(go(x, strategy(positions)))
def identify_trades(strategy: Strategy, positions: Dict[str, List[Position]], trades: List[Tuple[str, Lot]]) -> Dict[str, List[Position]]:
    """Fold each ``(symbol, lot)`` trade, in order, into per-symbol positions.

    A symbol with no entry (or a ``None`` entry) starts as ``[Open(lot)]``.
    Otherwise the lot is merged into that symbol's list with
    :func:`add_to_positions` using ``strategy``.

    Returns a new dict; the ``positions`` dict itself is not modified.
    """
    # One copy up front instead of rebuilding the whole dict for every trade.
    result: Dict[str, List[Position]] = dict(positions)
    for sym, lot in trades:
        existing = result.get(sym)
        if existing is None:
            result[sym] = [Open(lot)]
        else:
            result[sym] = add_to_positions(strategy, lot, existing)
    return result
def wash_sales(positions: List[Position]) -> List[Position]:
def go(z: Zipper[Position]) -> Zipper[Position]:
match z:
case MkZipper(before, event@Closed(l@Lot(n, TimePrice(b, d)), pd@TimePrice(p, _), True), after) if eligible_losing_close(event):
xpre, x, xpost = break_(lambda x: eligible_new_open(d, x), before)
match x:
case Open(x, None):
return MkZipper(xpre + [adjusted(x)] + xpost, closed, after)
case _:
ypre, y, ypost = break_(lambda y: eligible_new_open(d, y), after)
match y:
case Open(y, None):
return MkZipper(before, closed, ypre + [adjusted(y)] + ypost)
case _:
return just_closed
case _:
return just_closed
where:
total_loss = n * (b - p)
closed = Closed(l, pd, False)
just_closed = MkZipper(before, closed, after)
adjusted(x@Lot(m, TimePrice(o, _))) = Open(x, o + total_loss / m)
def eligible_losing_close(event: Position) -> bool:
match event:
case Closed(Lot(n, TimePrice(p, _)), TimePrice(p_prime, _), True):
return n * (p_prime - p) < 0
case _:
return False
def eligible_new_open(d: datetime, pos: Position) -> bool:
match pos:
case Open(Lot(_, TimePrice(_, d_prime)), None):
return d != d_prime and within_days(30, d, d_prime)
case _:
return False
return survey(go, positions)
def within_days(days: int, x: datetime, y: datetime) -> bool:
    """Return True if ``x`` and ``y`` are strictly less than ``days`` days apart.

    Order does not matter. Exactly ``days`` days apart returns False.
    """
    gap = abs(x - y)
    return gap < timedelta(days=days)
@dataclass
class MkZipper(Generic[T]):
before: List[T]
focus: T
after: List[T]
def break_(predicate: Callable[[T], bool], xs: List[T]) -> Tuple[List[T], Optional[T], List[T]]:
    """Split ``xs`` at the first element satisfying ``predicate``.

    Returns ``(prefix, match, suffix)``, where ``match`` is the first matching
    element and ``prefix``/``suffix`` are the elements before/after it. If
    nothing matches, returns ``(xs, None, [])``; ``xs`` is the same object,
    not a copy. ``predicate`` is evaluated left to right and stops at the
    first match.
    """
    hit = next((idx for idx, item in enumerate(xs) if predicate(item)), None)
    if hit is None:
        return xs, None, []
    return xs[:hit], xs[hit], xs[hit + 1:]
# ``Zipper`` is the name used in type annotations (e.g. ``Zipper[T]``).
# ``MkZipper`` is the same class, used as the constructor and match pattern.
Zipper = MkZipper
def survey(f: Callable[[Zipper[T]], Zipper[T]], xs: List[T]) -> List[T]:
    """Visit every element of ``xs`` as the focus of a zipper, left to right.

    ``f`` receives ``MkZipper(before, focus, after)``, where ``before`` and
    ``after`` are the neighbouring elements in their original order. ``f``
    may rewrite any of the three parts. The traversal then moves the
    (possibly rewritten) focus onto ``before`` and advances into whatever
    ``after`` list ``f`` returned. Traversal stops once ``after`` is empty,
    and the flattened final zipper is returned.

    An empty ``xs`` yields ``[]``.
    """
    if not xs:
        return []
    zipper = MkZipper([], xs[0], xs[1:])
    while True:
        zipper = f(zipper)
        if not zipper.after:
            return zipper.before + [zipper.focus]
        zipper = MkZipper(zipper.before + [zipper.focus], zipper.after[0], zipper.after[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment