Skip to content

Instantly share code, notes, and snippets.

@tuttlem
Created September 12, 2025 23:25
Show Gist options
  • Save tuttlem/39bd51e7560dc4523ae69b1dcaefa6b1 to your computer and use it in GitHub Desktop.
Save tuttlem/39bd51e7560dc4523ae69b1dcaefa6b1 to your computer and use it in GitHub Desktop.
from __future__ import annotations
from dataclasses import dataclass
from typing import (
Any,
Callable,
Generic,
Iterable,
Iterator,
Optional,
TypeVar,
overload,
Union,
cast,
Generator
)
import inspect
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")


# ---------------------------
# Core ADT
# ---------------------------
class Maybe(Generic[T]):
    """
    Minimal laws-friendly Maybe monad / applicative / functor.

    Subclasses:
    - Some(value)
    - Nothing()

    Operations:
    - map(f): Functor map
    - bind(f): Monad bind, where f: T -> Maybe[U] (flat_map is an alias)
    - ap(mb): Applicative apply: self holds a function, mb its argument
    - filter(p): Keep the value if predicate holds, else Nothing
    - get_or_else(x): Return contained value or default x
    - or_else(mb): Return self if Some else provided Maybe
    - match(...): Pattern matching helper returning a value
    - to_iter(): Iterate 0 or 1 elements
    - unwrap(): Unsafe extraction; raises ValueError on Nothing

    The algebra and extraction methods are implemented by the subclasses;
    on this base class they raise NotImplementedError.
    """

    # -- construction helpers --
    @staticmethod
    def some(value: T) -> "Maybe[T]":
        """Wrap ``value`` in a Some."""
        return Some(value)

    @staticmethod
    def nothing() -> "Maybe[T]":
        """Return the shared Nothing instance."""
        return NOTHING  # singleton

    @staticmethod
    def from_nullable(value: Optional[T]) -> "Maybe[T]":
        """Return Nothing for ``None``, otherwise Some(value)."""
        return Nothing() if value is None else Some(value)

    @staticmethod
    def from_predicate(value: T, predicate: Callable[[T], bool]) -> "Maybe[T]":
        """Return Some(value) when ``predicate(value)`` is truthy, else Nothing."""
        return Some(value) if predicate(value) else Nothing()

    # -- core algebra --
    def map(self, f: Callable[[T], U]) -> "Maybe[U]":
        raise NotImplementedError

    def bind(self, f: Callable[[T], "Maybe[U]"]) -> "Maybe[U]":
        raise NotImplementedError

    # alias for Haskell/Python communities
    def flat_map(self, f: Callable[[T], "Maybe[U]"]) -> "Maybe[U]":
        """Alias for bind().

        This is a delegating method on purpose. A class-level alias
        (``flat_map = bind``) would capture this base implementation, so
        subclasses overriding ``bind`` would still hit NotImplementedError
        when called through ``flat_map``.
        """
        return self.bind(f)

    def ap(self: "Maybe[Callable[[T], U]]", mb: "Maybe[T]") -> "Maybe[U]":
        raise NotImplementedError

    def filter(self, predicate: Callable[[T], bool]) -> "Maybe[T]":
        raise NotImplementedError

    # -- extraction/combination --
    def get_or_else(self, default: U) -> Union[T, U]:
        raise NotImplementedError

    def or_else(self, other: "Maybe[T]") -> "Maybe[T]":
        raise NotImplementedError

    # -- utilities --
    def is_some(self) -> bool:
        """Return True when this is a Some."""
        return isinstance(self, Some)

    def is_nothing(self) -> bool:
        """Return True when this is a Nothing."""
        return isinstance(self, Nothing)

    @overload
    def match(self, *, some: Callable[[T], U], nothing: Callable[[], U]) -> U: ...

    @overload
    def match(self, *, some: Callable[[T], U], default: U) -> U: ...

    def match(self, **kwargs: Any) -> Any:
        """
        match(
            some=lambda v: ...,
            nothing=lambda: ...  # or default=...
        )

        For a Some, calls ``some(value)``. Otherwise calls ``nothing()`` if
        given, else returns ``default``.

        Raises:
            TypeError: if the handler needed for the current case is missing.
        """
        if isinstance(self, Some):
            fn = kwargs.get("some")
            if fn is None:
                raise TypeError("match() requires a 'some' callable")
            return fn(self.value)
        # 'nothing' takes precedence over 'default' when both are supplied.
        if "nothing" in kwargs:
            return kwargs["nothing"]()
        if "default" in kwargs:
            return kwargs["default"]
        raise TypeError("match() requires 'nothing' or 'default' when Nothing")

    def to_iter(self) -> Iterator[T]:
        """Return an iterator over the contained value (zero or one item)."""
        if isinstance(self, Some):
            return iter((self.value,))
        return iter(())

    def __iter__(self) -> Iterator[T]:
        return self.to_iter()

    def __bool__(self) -> bool:
        # Truthy if Some, falsy if Nothing
        return self.is_some()

    # Unsafe: raises on Nothing (handy in scripts/tests; avoid in prod)
    def unwrap(self) -> T:
        """Return the contained value.

        Raises:
            ValueError: if this is not a Some.
        """
        if isinstance(self, Some):
            return self.value
        raise ValueError("Tried to unwrap Nothing")
@dataclass(frozen=True)
class Some(Maybe[T]):
    """
    The "present" case of Maybe: an immutable wrapper around one value.

    map, bind and filter absorb any Exception raised by the supplied
    callable and turn it into Nothing instead of propagating it.
    """

    value: T

    def map(self, f: Callable[[T], U]) -> "Maybe[U]":
        """Apply ``f`` to the value and re-wrap the result in Some."""
        try:
            mapped = f(self.value)
        except Exception:
            # If you prefer raising, change this behavior; Maybe often absorbs.
            return Nothing()
        return Some(mapped)

    def bind(self, f: Callable[[T], "Maybe[U]"]) -> "Maybe[U]":
        """Apply ``f`` to the value and return its result unchanged."""
        try:
            chained = f(self.value)
        except Exception:
            return Nothing()
        return chained

    def ap(self: "Maybe[Callable[[T], U]]", mb: "Maybe[T]") -> "Maybe[U]":
        """Apply the function held by this Some to the value inside ``mb``."""
        held = cast("Some[Callable[[T], U]]", self)
        return mb.map(held.value)

    def filter(self, predicate: Callable[[T], bool]) -> "Maybe[T]":
        """Return self when ``predicate(value)`` is truthy, else Nothing."""
        try:
            # Truthiness is evaluated inside the try so a failing __bool__
            # is absorbed the same way as a failing predicate.
            keep = bool(predicate(self.value))
        except Exception:
            return Nothing()
        if keep:
            return self
        return Nothing()

    def get_or_else(self, default: U) -> Union[T, U]:
        """Return the contained value; ``default`` is ignored."""
        return self.value

    def or_else(self, other: "Maybe[T]") -> "Maybe[T]":
        """Return self; ``other`` is ignored."""
        return self

    def __repr__(self) -> str:
        return f"Some({self.value!r})"
class Nothing(Maybe[Any]):
    """
    The "absent" case of Maybe.

    Every call to ``Nothing()`` returns the same instance. Each algebra
    method ignores its argument and returns this instance.
    """

    __slots__ = ()

    def __new__(cls) -> "Nothing":
        # singleton: create the instance on first use, then keep reusing it
        try:
            return cls._instance  # type: ignore[attr-defined]
        except AttributeError:
            cls._instance = super().__new__(cls)
            return cls._instance  # type: ignore[attr-defined]

    # -- core algebra --
    def map(self, f: Callable[[Any], U]) -> "Maybe[U]":
        """Return self; ``f`` is never called."""
        return self

    def bind(self, f: Callable[[Any], "Maybe[U]"]) -> "Maybe[U]":
        """Return self; ``f`` is never called."""
        return self

    def ap(self: "Maybe[Callable[[Any], U]]", mb: "Maybe[Any]") -> "Maybe[U]":
        """Return self; ``mb`` is not inspected."""
        return self

    def filter(self, predicate: Callable[[Any], bool]) -> "Maybe[Any]":
        """Return self; ``predicate`` is never called."""
        return self

    # -- extraction/combination --
    def get_or_else(self, default: U) -> U:
        """Return ``default``."""
        return default

    def or_else(self, other: "Maybe[T]") -> "Maybe[T]":
        """Return ``other``."""
        return other

    # -- dunder --
    def __repr__(self) -> str:
        return "Nothing()"

    def __reduce__(self):
        # ensure pickling returns the singleton
        return (Nothing, ())


# Singleton instance for Nothing
NOTHING: Nothing = Nothing()
# ---------------------------
# Helpers / Lifts
# ---------------------------
def lift(f: Callable[[T], U]) -> Callable[[Maybe[T]], Maybe[U]]:
    """Turn a plain function into one that maps over a Maybe via ``map``."""

    def lifted(m: Maybe[T]) -> Maybe[U]:
        mapped = m.map(f)
        return mapped

    return lifted
def lift2(f: Callable[[T, U], V]) -> Callable[[Maybe[T], Maybe[U]], Maybe[V]]:
    """Turn a two-argument function into one that takes two Maybe values."""

    def curried(a):
        # One argument at a time, so each ``ap`` step can feed in one value.
        return lambda b: f(a, b)

    def lifted(ma: Maybe[T], mb: Maybe[U]) -> Maybe[V]:
        partially_applied = Some(curried).ap(ma)
        return partially_applied.ap(mb)  # applicative style

    return lifted
def sequence(iterable: Iterable[Maybe[T]]) -> Maybe[list[T]]:
    """
    Collapse an iterable of Maybe values into a single Maybe of a list.

    Stops consuming ``iterable`` at the first Nothing and returns Nothing;
    otherwise returns Some of the unwrapped values in order.
    """
    values: list[T] = []
    for maybe in iterable:
        if maybe.is_nothing():
            break
        values.append(cast("Some[T]", maybe).value)
    else:
        # Loop finished without hitting a Nothing.
        return Some(values)
    return NOTHING
def traverse(iterable: Iterable[T], f: Callable[[T], Maybe[U]]) -> Maybe[list[U]]:
    """
    Apply f: T -> Maybe[U] to each element, then sequence the results.
    """
    # map() is lazy, so ``f`` is only called as sequence() consumes items.
    lazily_mapped = map(f, iterable)
    return sequence(lazily_mapped)
# ---- maybe_do: generator-based "do notation" for Maybe ----
def maybe_do(fn: Callable[..., Generator["Maybe[T]", T, U]]):
    """
    Decorate a generator function that yields Maybe values.

    Usage:
        @maybe_do
        def program(...):
            x = yield Some(10)
            y = yield half_if_even(x)  # returns Maybe[int]
            z = yield Maybe.from_predicate(y, lambda n: n > 2)
            return x + y + z  # becomes Some(result) or Nothing()

    Semantics:
    - On each `yield <Maybe>`, if it's Some(v), `v` is sent back into the generator.
    - If it's Nothing, the whole computation short-circuits to Nothing().
    - On normal completion, wraps the returned value in Some(...).
    - If `fn` does not return a generator, its result is wrapped in Some(...).
    - The generator is always closed before the wrapper returns or raises,
      so `finally` / `with` blocks inside it run at a predictable point.

    Raises:
        TypeError: if the generator yields something that is not a Maybe.
        Exceptions raised inside the generator body propagate unchanged.
    """
    import functools  # local import keeps this decorator self-contained

    @functools.wraps(fn)  # preserve fn's name and docstring on the wrapper
    def wrapper(*args, **kwargs) -> "Maybe[U]":
        gen = fn(*args, **kwargs)
        if not inspect.isgenerator(gen):
            # If the function isn't a generator, treat as pure and lift to Some
            return Some(cast(U, gen))  # type: ignore[arg-type]
        try:
            step = next(gen)  # first yielded Maybe
            while True:
                if not isinstance(step, Maybe):
                    raise TypeError("maybe_do: yielded non-Maybe value")
                if step.is_nothing():
                    return NOTHING
                value = cast("Some[T]", step).value
                step = gen.send(value)  # send unwrapped value back in
        except StopIteration as stop:
            return Some(cast(U, stop.value))
        finally:
            # No-op when the generator already finished. On the short-circuit
            # and TypeError paths it stops the suspended generator now.
            gen.close()

    return wrapper
# ---- Optional: async version for awaitables of Maybe ----
from typing import Awaitable
A = TypeVar("A")
R = TypeVar("R")


def maybe_do_async(fn: Callable[..., "AsyncGenerator[Any, Any]"]):
    """
    Decorate an *async* generator where each yield is a Maybe or an
    awaitable of Maybe.

    Example:
        @maybe_do_async
        async def flow(uid):
            u = yield fetch_user(uid)   # fetch_user -> Awaitable[Maybe[User]]
            p = yield fetch_profile(u)  # last yielded Maybe is the result
            # `await flow(uid)` -> Some(profile) or Nothing()

    Semantics:
    - A yielded awaitable is awaited by the wrapper; a raw Maybe is used as is.
    - If the Maybe is Some(v), `v` is sent back into the generator.
    - If it is Nothing, the computation short-circuits to Nothing().
    - Async generators cannot `return` a value and StopAsyncIteration
      carries none, so the result is Some(v) for the value of the LAST
      yielded Maybe (Some(None) if nothing was yielded).
    - If `fn` does not return an async generator, its result (awaited when
      awaitable) is wrapped in Some(...), mirroring `maybe_do`.
    - The async generator is always closed before the wrapper returns or raises.

    Raises:
        TypeError: if a yielded value does not resolve to a Maybe.
        Exceptions raised inside the generator body propagate unchanged.
    """
    import functools  # local import keeps this decorator self-contained

    @functools.wraps(fn)  # preserve fn's name and docstring on the wrapper
    async def wrapper(*args, **kwargs) -> "Maybe[Any]":
        agen = fn(*args, **kwargs)
        if not inspect.isasyncgen(agen):
            # Not an async generator: treat as pure and lift to Some
            result = await agen if inspect.isawaitable(agen) else agen
            return Some(result)
        last: Any = None  # value of the most recent Some; becomes the result
        try:
            step = await agen.__anext__()  # first yielded thing
            while True:
                # step may be Awaitable[Maybe[X]] or Maybe[X]
                m = await step if inspect.isawaitable(step) else step
                if not isinstance(m, Maybe):
                    raise TypeError("maybe_do_async: yielded non-Maybe value")
                if m.is_nothing():
                    return NOTHING
                last = cast("Some[Any]", m).value
                step = await agen.asend(last)
        except StopAsyncIteration:
            # StopAsyncIteration has no `.value`; report the last value instead.
            return Some(last)
        finally:
            # No-op when the generator already finished. On the short-circuit
            # and TypeError paths it stops the suspended generator now.
            await agen.aclose()

    return wrapper
# ---------------------------
# Examples / Quick checks
# ---------------------------
if __name__ == "__main__":
    # Construction
    a, b = Maybe.some(10), Maybe.nothing()
    c, d = Maybe.from_nullable("hi"), Maybe.from_nullable(None)
    print(a, b, c, d)  # Some(10) Nothing() Some('hi') Nothing()

    # map / bind
    def inc(x):
        return x + 1

    def half_if_even(x):
        if x % 2 == 0:
            return Some(x // 2)
        return NOTHING

    print(a.map(inc))  # Some(11)
    print(b.map(inc))  # Nothing()
    print(a.bind(half_if_even))  # Some(5)
    print(Some(3).bind(half_if_even))  # Nothing()

    # ap / lift2
    def add(x, y):
        return x + y

    print(lift2(add)(Some(2), Some(40)))  # Some(42)
    print(lift2(add)(Some(2), NOTHING))  # Nothing()

    # filter
    print(Some(8).filter(lambda x: x > 5))  # Some(8)
    print(Some(1).filter(lambda x: x > 5))  # Nothing()
    print(NOTHING.filter(lambda x: x > 5))  # Nothing()

    # defaulting / combination
    print(b.get_or_else(999))  # 999
    print(a.get_or_else(999))  # 10
    print(b.or_else(Some("alt")))  # Some('alt')
    print(a.or_else(Some("alt")))  # Some(10)

    # match
    msg = a.match(
        some=lambda v: f"value={v}",
        nothing=lambda: "nope",
    )
    print(msg)  # value=10

    # iteration
    print(list(Some("x")))  # ['x']
    print(list(NOTHING))  # []

    # sequence / traverse
    print(sequence([Some(1), Some(2), Some(3)]))  # Some([1,2,3])
    print(sequence([Some(1), NOTHING, Some(3)]))  # Nothing()
    print(traverse([2, 4, 5], half_if_even))  # Nothing()
    print(traverse([2, 4, 6], half_if_even))  # Some([1,2,3])

    # Quick law sanity (not exhaustive/proofs)
    # Functor identity: map(id) == self
    def identity(x):
        return x

    assert a.map(identity) == a
    assert b.map(identity) == b

    # Functor composition: map(f∘g) == map(g).map(f)
    def f(x):
        return x + 2

    def g(x):
        return x * 3

    assert a.map(lambda x: f(g(x))) == a.map(g).map(f)

    # Monad left identity: Some(x).bind(f) == f(x)
    assert Some(4).bind(half_if_even) == half_if_even(4)

    # Monad right identity: m.bind(Some) == m
    assert a.bind(Some) == a
    assert b.bind(Some) == b

    # Associativity: m.bind(f).bind(g) == m.bind(lambda x: f(x).bind(g))
    def f_m(x: int) -> Maybe[int]:
        return Some(x + 1)

    def g_m(x: int) -> Maybe[int]:
        return Some(x * 2)

    lhs = a.bind(f_m).bind(g_m)
    rhs = a.bind(lambda x: f_m(x).bind(g_m))
    assert lhs == rhs

    print("All quick checks passed.")
# ---------------------------
# Demo (sync)
# ---------------------------
# Reuse your Maybe/Some/Nothing from earlier
def half_if_even(x: int) -> Maybe[int]:
    """Return Some(x // 2) when ``x`` is even, otherwise Nothing."""
    if x % 2 != 0:
        return NOTHING
    return Some(x // 2)
@maybe_do
def pipeline(start: int) -> "Generator[Maybe[int], int, int]":
    """Three-step demo: increment, halve if even, then guard on > 4.

    Through ``maybe_do`` each ``yield`` hands back the unwrapped value, and
    the first Nothing ends the whole pipeline with Nothing.
    """
    bumped = yield Some(start + 1)  # unwrap Some -> bumped
    halved = yield half_if_even(bumped)  # maybe halve if even
    guarded = yield Maybe.from_predicate(halved + 3, lambda n: n > 4)  # guard
    total = bumped + halved + guarded
    return total
print(pipeline(3)) # a=4, b=2, c=5 -> Some(4 + 2 + 5) = Some(11)
print(pipeline(0)) # a=1 is odd, so half_if_even(1) is Nothing -> Nothing()
print(pipeline(1)) # a=2, b=1, but the guard (1 + 3 > 4) fails -> Nothing()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment