Skip to content

Instantly share code, notes, and snippets.

@rntz
Last active April 21, 2025 16:45
Show Gist options
  • Save rntz/9f07852440e6de1743bdf9f5b05167d6 to your computer and use it in GitHub Desktop.
Save rntz/9f07852440e6de1743bdf9f5b05167d6 to your computer and use it in GitHub Desktop.
Lambda-join implementations via miniKanren-style search
module Kanren where
import Control.Applicative
import Control.Monad
import Data.Monoid
import Data.List (intercalate)
-- | The microKanren search monad: a lazy stream of answers that may pause.
--
-- Its 'MonadPlus' instance implements a /complete/ search strategy, even over
-- infinite search spaces -- unlike e.g. the list monad -- AS LONG AS every
-- potentially infinite loop is guarded by 'Later'. 'Later' is a signal to
-- "switch branches" and explore a different part of the search space; insert
-- it anywhere recursion could happen.
--
-- Experienced miniKanrenners (e.g. Will Byrd and especially Michael Ballantyne)
-- have good intuition about where inserting 'Later' matters for performance.
-- Might be worth chatting with them.
data Search a
  = Fail               -- ^ no (more) answers
  | Cons a (Search a)  -- ^ one answer, followed by the rest of the stream
  | Later (Search a)   -- ^ a pause: give other branches a turn
  deriving Functor

instance Applicative Search where
  pure x = Cons x Fail
  (<*>) = ap

instance Monad Search where
  m >>= k = case m of
    Fail -> Fail
    Later rest -> Later (rest >>= k)
    -- The answers of (k x) are interleaved with those of the remaining stream.
    Cons x rest -> k x <|> (rest >>= k)

instance Alternative Search where
  empty = Fail
  left <|> right = case left of
    Fail -> right
    -- Swapping the two streams on a pause is crucial for completeness: a
    -- stalled branch hands control to the other one instead of starving it.
    Later rest -> Later (right <|> rest)
    Cons x rest -> Cons x (rest <|> right)

-- | Uses the default methods, i.e. 'mzero' = 'empty' and 'mplus' = '<|>'.
instance MonadPlus Search
-- Ok, now let's use this to implement lambda-join. The insight is that
-- lambda-join's semantics are nondeterministic, so we're essentially
-- searching the space of evaluations of a program!

type Symbol = String

-- | Runtime values of lambda-join.
data Value
  = VBot                           -- ^ the bottom value, shown as ⊥ᵥ
  | VPair Value Value
  | VSet [Value]                   -- ^ a set, represented as a list
  | VFun (Value -> Compute Value)
  | VSym Symbol

-- | A computation: a search over all the results evaluation may reach.
type Compute a = Search (Result a)

-- | The outcome of one evaluation path: a value, bottom (⊥) or error (⊤).
data Result a = Value a | Bot | Top
  deriving Functor

instance Applicative Result where
  pure = Value
  (<*>) = ap

-- | 'Bot' and 'Top' both short-circuit; only a 'Value' reaches the continuation.
instance Monad Result where
  r >>= k = case r of
    Value v -> k v
    Bot -> Bot
    Top -> Top
-- Actually a semilattice: 'Bot' is the identity, 'Top' absorbs everything
-- else, and two values are combined with 'merge'.
instance Semigroup (Result Value) where
  Bot <> y = y
  x <> Bot = x
  Top <> _ = Top
  _ <> Top = Top
  Value x <> Value y = merge x y

instance Monoid (Result Value) where mempty = Bot

-- | The partial join of two values. Values of incompatible shapes (e.g. a
-- pair and a set) have no join, so they produce 'Top'. Never returns 'Bot'.
merge :: Value -> Value -> Result Value
merge VBot y = Value y
merge x VBot = Value x -- this case MUST go before those that return Top!
-- Pairs join componentwise; a clash in either component is a clash overall.
merge (VPair l1 r1) (VPair l2 r2) = VPair <$> merge l1 l2 <*> merge r1 r2
merge VPair{} _ = Top
-- Sets join by list concatenation (duplicates are kept).
merge (VSet xs) (VSet ys) = Value $ VSet (xs ++ ys)
merge VSet{} _ = Top
-- Functions join pointwise: every result of (f x) joined with every result
-- of (g x).
merge (VFun f) (VFun g) = Value $ VFun $ \x -> (<>) <$> f x <*> g x
merge VFun{} _ = Top
-- A symbol only joins with itself.
merge (VSym s1) (VSym s2) | s1 == s2 = Value (VSym s1)
merge VSym{} _ = Top
-- Term evaluation

type Var = String

-- | Lambda-join terms.
data Term
  = Bottom                     -- ^ evaluates to 'Bot'
  | Error                      -- ^ evaluates to 'Top'
  | BotV                       -- ^ the bottom value ⊥ᵥ
  | Join Term Term
  | Var Var
  | Lam Var Term
  | App Term Term
  | Pair Term Term
  | LetPair Var Var Term Term
  | Set [Term]
  | BigJoin Var Term Term
  | Symbol Symbol
  | LetSymbol Symbol Term Term
  | Prim Value                 -- ^ hack: embed a Haskell value directly
  deriving Show

-- | Variable bindings; the innermost binding comes first.
type Env = [(Var, Value)]

-- | Keeps the search from looping unproductively: answer 'Bot' right away,
-- then pause before running the given computation.
delay :: Compute a -> Compute a
delay k = Cons Bot (Later k)
-- Invariant: (eval env t) produces at least one search result. In particular,
-- (eval env t /= Fail).
--
-- Why? Because any term can at least step to Bot. To enforce this, we use
-- `delay` in the `App` case. (If we were being fully faithful to the lambda
-- join paper's semantics, EVERY case of eval would start with `delay`.) I
-- believe only putting it in App is sound, but have low confidence.
--
-- We depend on this invariant in several places, eg, in the Set case of eval.
--
-- Returns the search over every result the term may evaluate to under `env`.
eval :: Env -> Term -> Compute Value
eval env (Prim v) = pure $ Value v
eval env Bottom = pure Bot
eval env Error = pure Top
eval env BotV = pure (Value VBot)
-- Every result of t1 joined with every result of t2 (Semigroup on Result).
eval env (Join t1 t2) = (<>) <$> eval env t1 <*> eval env t2
eval env (Var x)
  | Just v <- lookup x env = pure (Value v)
  | otherwise = error $ "unbound variable: " ++ x
eval env (Lam x body) = pure . Value . VFun $ \v -> eval ((x,v):env) body
eval env (App tfun targ) = delay $ do
  fun <- eval env tfun
  case fun of
    Top -> pure Top
    Value (VFun f) -> do
      arg <- eval env targ
      case arg of
        Value a -> f a
        -- A Bot or Top argument is propagated without calling f.
        x -> pure x
    -- It's strange to me that applying a non-function does not error/Top, but
    -- that seems to be the semantics in the paper.
    _ -> pure Bot
eval env (Pair t1 t2) = do
  r1 <- eval env t1
  r2 <- eval env t2
  -- NOTE(review): Result's applicative inspects r1 first, so (Bot, Top) gives
  -- Bot while (Top, Bot) gives Top. Confirm this asymmetry is intended.
  pure $ VPair <$> r1 <*> r2
eval env (LetPair x y tpair tbody) = do
  rpair <- eval env tpair
  case rpair of
    Top -> pure Top
    Value (VPair vx vy) -> eval ((x,vx):(y,vy):env) tbody
    _ -> pure Bot
eval env (Set terms) = do
  -- Choose one result per element; any error makes the whole set an error.
  results <- mapM (eval env) terms
  pure $ if or [True | Top <- results]
    then Top
    else Value $ VSet [v | Value v <- results] --drop bottoms
eval env (BigJoin x tset tbody) = do
  rset <- eval env tset
  case rset of
    Top -> pure Top
    -- Run the body once per element, then join all of those results.
    Value (VSet values) -> mconcat <$> sequence [eval ((x,v):env) tbody | v <- values]
    _ -> pure Bot
eval env (Symbol sym) = pure $ Value $ VSym sym
eval env (LetSymbol sym tsym tbody) = do
  rsym <- eval env tsym
  case rsym of
    Top -> pure Top
    Value (VSym vsym) | sym == vsym -> eval env tbody
    _ -> pure Bot
-- EXAMPLES

-- | Rendering with a budget: the 'Int' bounds how much of a possibly
-- infinite structure gets rendered.
class Display a where repr :: Int -> a -> String

-- | Print something using a rendering budget of @n@.
display :: Display a => Int -> a -> IO ()
display n = putStrLn . repr n

instance Show Value where
  show VBot = "⊥ᵥ"
  show (VPair x y) = "(" ++ show x ++ ", " ++ show y ++ ")"
  show (VSet vs) = "{" ++ intercalate ", " (map show vs) ++ "}"
  show VFun{} = "<fun>"
  show (VSym s) = s

instance Show a => Show (Result a) where
  show Bot = "⊥"
  show Top = "⊤"
  show (Value v) = show v

-- | Answers are separated by ", " and each 'Later' is drawn as a dot. Every
-- 'Later' costs one unit of budget. A bare 'Fail' is shown as "nil". Once the
-- budget is used up at a 'Later', "?" stands for the unexplored remainder.
instance Show a => Display (Search a) where
  -- `<= 0` rather than matching exactly 0: with a negative budget the old
  -- pattern never matched, so rendering an infinite search never terminated.
  repr n Later{} | n <= 0 = "?"
  repr n (Later x@Later{}) = "." ++ repr (n-1) x
  repr n (Later x) = ". " ++ repr (n-1) x
  repr _ Fail = "nil"
  repr _ (Cons x Fail) = show x
  repr n (Cons x y@Later{}) = show x ++ repr n y
  repr n (Cons x xs) = show x ++ ", " ++ repr n xs
-- fixed point combinator

-- | The call-by-value fixed point combinator (Z combinator):
-- @z = λf. wub wub@ where @wub = λx. f (λy. x x y)@.
z :: Term
z = Lam "f" (App wub wub)
  where wub = Lam "x" (App (Var "f") (Lam "y" (App (App (Var "x") (Var "x")) (Var "y"))))

-- | The fixed point of @λself. self@: a function that, when applied, just
-- re-applies itself forever.
botfun :: Term
botfun = App z (Lam "self" (Var "self"))

-- | A number literal. Numbers are represented as symbols holding the 'show'n
-- text of the number.
num :: Int -> Term
num n = Prim $ VSym $ show n

-- | Curried primitive addition on numeric symbols. Adding anything other than
-- two numeric symbols is an error ('Top').
primAdd :: Value
primAdd = VFun $ \x -> pure $ Value $ VFun $ \y -> pure $ foo x y
  where
    foo (VSym x) (VSym y)
      | Just a <- parseInt x, Just b <- parseInt y = Value $ VSym (show (a + b))
    foo _ _ = Top
    -- Total version of `read` (it accepts exactly the same syntax). A
    -- non-numeric symbol now yields Top instead of crashing with
    -- "Prelude.read: no parse".
    parseInt :: String -> Maybe Int
    parseInt s = case [n | (n, rest) <- reads s, ("", "") <- lex rest] of
      [n] -> Just n
      _ -> Nothing

add :: Term
add = Prim primAdd

-- | 2 + 2
four :: Term
four = add `App` num 2 `App` num 2

-- | Takes a set of numbers and adds 2 to each of its elements.
plus2all :: Term
plus2all = Lam "xs" $ BigJoin "x" (Var "xs") $ Set [(add `App` Var "x") `App` num 2]

-- | The even numbers, defined recursively with the z fixed point combinator
-- (there is no built-in fixpoint construct): evens = {0} ∨ plus2all(evens).
-- Because z only finds fixed points of functions, evens is wrapped in a
-- function whose argument (0xbeef / 0xdead) is ignored.
zevens :: Term
zevens = (z `App` Lam "evens" (Lam "_" body)) `App` num 0xbeef
  where body = Set [num 0] `Join` (plus2all `App` (Var "evens" `App` num 0xdead))
module Kanren where
import Control.Applicative
import Control.Monad
import Data.Monoid
import Data.List (intercalate)
-- | A search stream inspired by (micro)Kanren, but with two terminators:
--
--   * 'Bot' is the default end of the stream -- any term can step to Bot.
--   * 'Top' means we encountered an error; it absorbs whatever would follow.
data Search a
  = Bot
  | Top
  | Cons a (Search a)
  | Later (Search a)
  deriving Functor

instance Applicative Search where
  pure x = Cons x Bot
  (<*>) = ap

instance Monad Search where
  m >>= k = case m of
    Bot -> Bot
    Top -> Top
    Later rest -> Later (rest >>= k)
    Cons x rest -> k x <|> (rest >>= k)

instance Alternative Search where
  empty = Bot
  left <|> right = case left of
    Bot -> right
    Top -> Top
    -- Swapping the two streams on a pause is crucial for completeness.
    Later rest -> Later (right <|> rest)
    Cons x rest -> Cons x (rest <|> right)
-- Ok, now let's use this to implement lambda-join. The insight is that lambda
-- join's semantics are nondeterministic, so we're essentially searching the
-- space of evaluations of a program!
type Symbol = String

-- Runtime values. A function returns a search over every value it may produce.
data Value = VBot
           | VPair Value Value
           | VSet [Value]
           | VFun (Value -> Search Value)
           | VSym Symbol

-- The partial semilattice on values.
-- Produces at most a single result, otherwise Top.
merge :: Value -> Value -> Search Value
merge VBot y = pure y
merge x VBot = pure x -- this case MUST go before those that return Top!
-- Pairs join componentwise; a clash in either component is Top.
merge (VPair l1 r1) (VPair l2 r2) = VPair <$> merge l1 l2 <*> merge r1 r2
merge VPair{} _ = Top
-- Sets join by list concatenation (duplicates are kept).
merge (VSet xs) (VSet ys) = pure $ VSet (xs ++ ys)
merge VSet{} _ = Top
-- Functions join pointwise, via the Semigroup on `Search Value`.
merge (VFun f) (VFun g) = pure $ VFun $ \x -> f x <> g x
merge VFun{} _ = Top
-- A symbol only joins with itself.
merge (VSym s1) (VSym s2) | s1 == s2 = pure (VSym s1)
merge VSym{} _ = Top
-- The semilattice on computations. Intended meaning (see the reference
-- definition commented out below): every result of xs, every result of ys,
-- and the merge of each x from xs with each y from ys.
instance Semigroup (Search Value) where
  -- -- We must explicitly include xs and ys because otherwise `Bot` is not an
  -- -- identity.
  -- --
  -- -- It doesn't suffice to just check for Bot because the same should hold for
  -- -- `Later Bot`, `Later (Later Bot)`, etc.
  -- xs <> ys = sboth <|> xs <|> ys
  -- where sboth = do x <- xs; y <- ys; merge x y
  -- Alternative, more performant, HOPEFULLY correct definition:
  -- An error on the left aborts the whole combination.
  Top <> _ = Top
  -- Bot is the identity.
  Bot <> ys = ys
  -- Switch sides to be exhaustive.
  Later xs <> ys = Later (ys <> xs)
  -- If xs produces a value x, we could
  -- 1. produce that element, x
  -- 2. merge it with any elements of ys
  -- 3. combine something else from xs with ys
  Cons x xs <> ys = Cons x $ (merge x =<< ys) <|> (xs <> ys)

instance Monoid (Search Value) where mempty = Bot
-- Term evaluation
type Var = String
-- Lambda-join terms. `Bottom` and `Error` evaluate to the search terminators
-- Bot and Top respectively; `BotV` is the bottom *value* (shown as ⊥ᵥ).
data Term = Bottom | Error | BotV
          | Join Term Term
          | Var Var | Lam Var Term | App Term Term
          | Pair Term Term | LetPair Var Var Term Term
          | Set [Term] | BigJoin Var Term Term
          | Symbol Symbol | LetSymbol Symbol Term Term
          | Prim Value -- hack: embed a Haskell value (e.g. a primitive function) directly
          deriving Show
-- Variable bindings. New bindings are consed onto the front, so under
-- `lookup` the innermost binding of a name shadows the outer ones.
type Env = [(Var, Value)]

-- Enumerate the values a term may evaluate to. `Join` combines both sides
-- with the Semigroup instance on `Search Value`.
eval :: Env -> Term -> Search Value
eval env (Prim v) = pure v
eval env Bottom = Bot
eval env Error = Top
eval env BotV = pure VBot
eval env (Join t1 t2) = eval env t1 <> eval env t2
eval env (Var x)
  | Just v <- lookup x env = pure v
  | otherwise = error $ "unbound variable: " ++ x
eval env (Lam x body) = pure . VFun $ \v -> eval ((x,v):env) body
-- `Later` guards every application, since that is where recursion can happen.
eval env (App tfun targ) = Later $ do
  fun <- eval env tfun
  case fun of
    VFun f -> f =<< eval env targ
    -- It's strange to me that applying a non-function is not error/Top, but
    -- that seems to be the semantics in the paper.
    _ -> Bot
-- TODO: test this on each case of (eval env {t1,t2} = {Top,Bot}).
-- NOTE(review): t1 is consumed first, so Pair Bottom Error gives Bot while
-- Pair Error Bottom gives Top. Confirm this asymmetry is intended.
eval env (Pair t1 t2) = VPair <$> eval env t1 <*> eval env t2
eval env (LetPair x y tpair tbody) = do
  vpair <- eval env tpair
  case vpair of
    VPair v1 v2 -> eval ((x,v1):(y,v2):env) tbody
    _ -> Bot --again, strange to me that this isn't Top.
eval env (Set terms) = do
  -- The normal behavior of the monad on Search is "strict", in that Bot
  -- terminates the search: Bot >>= f = Bot. But we don't want that here. A
  -- Bot element should simply not contribute to the set, rather than stopping
  -- the other elements from being collected.
  --
  -- The Nothing/Just here is to correctly handle
  -- terms that evaluate to Bot and don't contribute to the set: each element
  -- contributes either nothing (Nothing) or one of its values (Just v).
  vals <- forM terms $ \t -> pure Nothing <|> fmap Just (eval env t)
  pure $ VSet [v | Just v <- vals]
eval env (BigJoin x tset tbody) = do
  vset <- eval env tset
  case vset of
    -- Join the body's results over all elements (Monoid on Search Value).
    VSet vs -> mconcat [eval ((x,v):env) tbody | v <- vs]
    _ -> Bot
eval env (Symbol sym) = pure $ VSym sym
eval env (LetSymbol sym tsym tbody) = do
  rsym <- eval env tsym
  case rsym of
    VSym vsym | sym == vsym -> eval env tbody
    _ -> Bot
-- EXAMPLES

-- | Rendering with a budget: the 'Int' bounds how much of a possibly
-- infinite structure gets rendered.
class Display a where repr :: Int -> a -> String

-- | Print something using a rendering budget of @n@.
display :: Display a => Int -> a -> IO ()
display n = putStrLn . repr n

instance Show Value where
  show VBot = "⊥ᵥ"
  show (VPair x y) = "(" ++ show x ++ ", " ++ show y ++ ")"
  show (VSet vs) = "{" ++ intercalate ", " (map show vs) ++ "}"
  show VFun{} = "<fun>"
  show (VSym s) = s

-- | Answers are separated by ", " and each 'Later' is drawn as a dot. Every
-- 'Later' costs one unit of budget. A bare 'Bot' is shown as "nil" and 'Top'
-- (an error) as "!!". Once the budget is used up at a 'Later', "?" stands for
-- the unexplored remainder.
instance Show a => Display (Search a) where
  -- `<= 0` rather than matching exactly 0: with a negative budget the old
  -- pattern never matched, so rendering an infinite search never terminated.
  repr n Later{} | n <= 0 = "?"
  repr n (Later x@Later{}) = "." ++ repr (n-1) x
  repr n (Later x) = ". " ++ repr (n-1) x
  repr _ Bot = "nil"
  repr _ Top = "!!"
  repr _ (Cons x Bot) = show x
  repr n (Cons x y@Later{}) = show x ++ repr n y
  repr n (Cons x xs) = show x ++ ", " ++ repr n xs
-- fixed point combinator

-- | The call-by-value fixed point combinator (Z combinator):
-- @z = λf. wub wub@ where @wub = λx. f (λy. x x y)@.
z :: Term
z = Lam "f" (App wub wub)
  where wub = Lam "x" (App (Var "f") (Lam "y" (App (App (Var "x") (Var "x")) (Var "y"))))

-- | The fixed point of @λself. self@: a function that, when applied, just
-- re-applies itself forever.
botfun :: Term
botfun = App z (Lam "self" (Var "self"))

-- | A number literal. Numbers are represented as symbols holding the 'show'n
-- text of the number.
num :: Int -> Term
num n = Prim $ VSym $ show n

-- | Curried primitive addition on numeric symbols. Adding anything other than
-- two numeric symbols is a Haskell-level error, not a lambda-join 'Top'.
primAdd :: Value
primAdd = VFun $ \x -> pure $ VFun $ \y -> pure $ foo x y
  where
    foo (VSym x) (VSym y)
      | Just a <- parseInt x, Just b <- parseInt y = VSym (show (a + b))
    foo _ _ = error "adding non numbers"
    -- Total version of `read` (it accepts exactly the same syntax). A
    -- non-numeric symbol now reaches the "adding non numbers" error instead
    -- of crashing with an opaque "Prelude.read: no parse".
    parseInt :: String -> Maybe Int
    parseInt s = case [n | (n, rest) <- reads s, ("", "") <- lex rest] of
      [n] -> Just n
      _ -> Nothing

add :: Term
add = Prim primAdd

-- | 2 + 2
four :: Term
four = add `App` num 2 `App` num 2

-- | Takes a set of numbers and adds 2 to each of its elements.
plus2all :: Term
plus2all = Lam "xs" $ BigJoin "x" (Var "xs") $ Set [(add `App` Var "x") `App` num 2]

-- | The even numbers, defined recursively with the z fixed point combinator
-- (there is no built-in fixpoint construct): evens = {0} ∨ plus2all(evens).
-- Because z only finds fixed points of functions, evens is wrapped in a
-- function whose argument (0xbeef / 0xdead) is ignored.
zevens :: Term
zevens = (z `App` Lam "evens" (Lam "_" body)) `App` num 0xbeef
  where body = Set [num 0] `Join` (plus2all `App` (Var "evens" `App` num 0xdead))
-- In this approach we "linearize" semilattice join (e1 ∨ e2) to avoid
-- combinatorial explosion. Instead of finding all values (v1 ∨ v2), such that
-- e1 ↦ v1 and e2 ↦ v2, we simply interleave evaluation of e1 and e2,
-- enumerating all values v such that _either_ e1 ↦ v or e2 ↦ v.
--
-- This, of course, is only sound for downstream operators that are _linear_,
-- that is, where f(x ∨ y) = f(x) ∨ f(y). But most of λ∨ _is_ linear in this
-- sense. And for those parts that aren't, we can insert an accumulation point,
-- where as we find new values we semilattice-join them into a "running
-- total"/"partial sum".
module KanrenMonotone where
import Control.Applicative
import Control.Monad
import Data.Monoid
import Data.List (intercalate)
-- | A search stream inspired by (micro)Kanren, but with two terminators:
--
--   * 'Bot' is the default end of the stream -- any term can step to Bot.
--   * 'Top' means we encountered an error; it absorbs whatever would follow.
data Search a
  = Bot
  | Top
  | Cons a (Search a)
  | Later (Search a)
  deriving Functor

instance Applicative Search where
  pure x = Cons x Bot
  (<*>) = ap

instance Monad Search where
  m >>= k = case m of
    Bot -> Bot
    Top -> Top
    Later rest -> Later (rest >>= k)
    Cons x rest -> k x <|> (rest >>= k)

instance Alternative Search where
  empty = Bot
  left <|> right = case left of
    Bot -> right
    Top -> Top
    -- Swapping the two streams on a pause is crucial for completeness.
    Later rest -> Later (right <|> rest)
    Cons x rest -> Cons x (rest <|> right)
-- Ok, now let's use this to implement lambda-join. The insight is that lambda
-- join's semantics are nondeterministic, so we're essentially searching the
-- space of evaluations of a program!
type Symbol = String
-- Runtime values. A function returns a search over every value it may produce.
data Value = VBot
           | VPair Value Value
           | VSet [Value]
           | VFun (Value -> Search Value)
           | VSym Symbol
-- Term evaluation
type Var = String
-- Lambda-join terms. `Bottom` and `Error` evaluate to the search terminators
-- Bot and Top respectively; `BotV` is the bottom *value* (shown as ⊥ᵥ).
data Term = Bottom | Error | BotV
          | Join Term Term
          | Var Var | Lam Var Term | App Term Term
          | Pair Term Term | LetPair Var Var Term Term
          | Set [Term] | BigJoin Var Term Term
          | Symbol Symbol | LetSymbol Symbol Term Term
          | Prim Value -- hack: embed a Haskell value (e.g. a primitive function) directly
          deriving Show
-- Variable bindings. New bindings are consed onto the front, so under
-- `lookup` the innermost binding of a name shadows the outer ones.
type Env = [(Var, Value)]
-- | Enumerate the values a term may evaluate to, "linearized" as described at
-- the top of this module: a 'Join' simply interleaves the values of both sides.
eval :: Env -> Term -> Search Value
eval env (Prim v) = pure v
eval env Bottom = Bot
eval env Error = Top
eval env BotV = pure VBot
eval env (Join t1 t2) = eval env t1 <|> eval env t2
eval env (Var x)
  | Just v <- lookup x env = pure v
  | otherwise = error $ "unbound variable: " ++ x
eval env (Lam x body) = pure . VFun $ \v -> eval ((x,v):env) body
-- `Later` guards every application, since that is where recursion can happen.
eval env (App tfun targ) = Later $ do
  fun <- eval env tfun
  case fun of
    VFun f -> f =<< eval env targ
    -- It's strange to me that applying a non-function is not error/Top, but
    -- that seems to be the semantics in the paper.
    _ -> Bot
-- TODO: test this on each case of (eval env {t1,t2} = {Top,Bot}).
eval env (Pair t1 t2) = VPair <$> eval env t1 <*> eval env t2
eval env (LetPair x y tpair tbody) = do
  vpair <- eval env tpair
  case vpair of
    VPair v1 v2 -> eval ((x,v1):(y,v2):env) tbody
    _ -> Bot --again, strange to me that this isn't Top.
-- A set is the join of its singletons: {t1, ..., tn} = {} ∨ {t1} ∨ ... ∨ {tn}.
-- The empty set is emitted explicitly because it is a value in its own right.
-- Without it, `Set []` (or a set whose elements all evaluate to Bot) would
-- produce no results at all. Strict consumers such as function application
-- or `Pair` would then silently yield nothing instead of a value. Emitting it
-- is harmless for linear consumers, since {} is the unit of set union.
eval env (Set terms) = pure (VSet []) <|> asum [fmap singleton (eval env t) | t <- terms]
  where singleton x = VSet [x]
eval env (BigJoin x tset tbody) = do
  vset <- eval env tset
  case vset of
    VSet vs -> asum [eval ((x,v):env) tbody | v <- vs]
    _ -> Bot
eval env (Symbol sym) = pure $ VSym sym
-- NB. if symbols had a nontrivial order/semilattice join, we'd need an
-- `accumulate` (see below) here in case `tsym` produced multiple symbols.
eval env (LetSymbol sym tsym tbody) = do
  rsym <- eval env tsym
  case rsym of
    VSym vsym | sym == vsym -> eval env tbody
    _ -> Bot
-- Semilattice join. This code is currently unused by eval (its only caller is
-- `accumulate`, which is itself unused), but it would be necessary to
-- handle nonlinear operators, ie. operators such that f(x ∨ y) ≠ f(x) ∨ f(y).
-- The partial semilattice on values.
-- Produces at most a single result, otherwise Top.
merge :: Value -> Value -> Search Value
merge VBot y = pure y
merge x VBot = pure x -- this case MUST go before those that return Top!
-- Pairs join componentwise; a clash in either component is Top.
merge (VPair l1 r1) (VPair l2 r2) = VPair <$> merge l1 l2 <*> merge r1 r2
merge VPair{} _ = Top
-- Sets join by list concatenation (duplicates are kept).
merge (VSet xs) (VSet ys) = pure $ VSet (xs ++ ys)
merge VSet{} _ = Top
-- Note use of <|> as "semilattice join" of two computations.
merge (VFun f) (VFun g) = pure $ VFun $ \x -> f x <|> g x
merge VFun{} _ = Top
-- A symbol only joins with itself.
merge (VSym s1) (VSym s2) | s1 == s2 = pure (VSym s1)
merge VSym{} _ = Top
-- Turns a sequence of values into a sequence of increasing partial "sums",
-- where addition = merge. Emits one partial sum per "tick" (ie. between
-- Laters).
--
-- A Top in the input, or from merge, ends the output with Top. Any partial
-- sum not yet emitted in the current tick is dropped. Not called by eval.
accumulate :: Search Value -> Search Value
accumulate Bot = Bot
accumulate Top = Top
accumulate (Later xs) = Later $ accumulate xs
-- The first value becomes the initial running total.
accumulate (Cons x xs) = accum True x xs
  where
    consIf True x xs = Cons x xs
    consIf False _ xs = xs
    -- we use a "dirty bit" to track whether we need to emit an updated
    -- value before the next `Later`.
    accum dirty sofar Bot = consIf dirty sofar Bot
    accum dirty sofar Top = Top
    -- end of tick; emit if dirty.
    accum dirty sofar (Later xs) = consIf dirty sofar $ Later $ accum False sofar xs
    accum dirty sofar (Cons x xs) =
      -- optimized for `merge` producing at most one value
      case merge sofar x of
        Top -> Top
        Cons sofar' Bot -> accum True sofar' xs
        _ -> error "impossible!"
-- EXAMPLES

-- | Rendering with a budget: the 'Int' bounds how much of a possibly
-- infinite structure gets rendered.
class Display a where repr :: Int -> a -> String

-- | Print something using a rendering budget of @n@.
display :: Display a => Int -> a -> IO ()
display n = putStrLn . repr n

instance Show Value where
  show VBot = "⊥ᵥ"
  show (VPair x y) = "(" ++ show x ++ ", " ++ show y ++ ")"
  show (VSet vs) = "{" ++ intercalate ", " (map show vs) ++ "}"
  show VFun{} = "<fun>"
  show (VSym s) = s

-- | Answers are separated by ", " and each 'Later' is drawn as a dot. Every
-- 'Later' costs one unit of budget. A bare 'Bot' is shown as "nil" and 'Top'
-- (an error) as "!!". Once the budget is used up at a 'Later', "?" stands for
-- the unexplored remainder.
instance Show a => Display (Search a) where
  -- `<= 0` rather than matching exactly 0: with a negative budget the old
  -- pattern never matched, so rendering an infinite search never terminated.
  repr n Later{} | n <= 0 = "?"
  repr n (Later x@Later{}) = "." ++ repr (n-1) x
  repr n (Later x) = ". " ++ repr (n-1) x
  repr _ Bot = "nil"
  repr _ Top = "!!"
  repr _ (Cons x Bot) = show x
  repr n (Cons x y@Later{}) = show x ++ repr n y
  repr n (Cons x xs) = show x ++ ", " ++ repr n xs
-- fixed point combinator

-- | The call-by-value fixed point combinator (Z combinator):
-- @z = λf. wub wub@ where @wub = λx. f (λy. x x y)@.
z :: Term
z = Lam "f" (App wub wub)
  where wub = Lam "x" (App (Var "f") (Lam "y" (App (App (Var "x") (Var "x")) (Var "y"))))

-- | The fixed point of @λself. self@: a function that, when applied, just
-- re-applies itself forever.
botfun :: Term
botfun = App z (Lam "self" (Var "self"))

-- | A number literal. Numbers are represented as symbols holding the 'show'n
-- text of the number.
num :: Int -> Term
num n = Prim $ VSym $ show n

-- | Curried primitive addition on numeric symbols. Adding anything other than
-- two numeric symbols is a Haskell-level error, not a lambda-join 'Top'.
primAdd :: Value
primAdd = VFun $ \x -> pure $ VFun $ \y -> pure $ foo x y
  where
    foo (VSym x) (VSym y)
      | Just a <- parseInt x, Just b <- parseInt y = VSym (show (a + b))
    foo _ _ = error "adding non numbers"
    -- Total version of `read` (it accepts exactly the same syntax). A
    -- non-numeric symbol now reaches the "adding non numbers" error instead
    -- of crashing with an opaque "Prelude.read: no parse".
    parseInt :: String -> Maybe Int
    parseInt s = case [n | (n, rest) <- reads s, ("", "") <- lex rest] of
      [n] -> Just n
      _ -> Nothing

add :: Term
add = Prim primAdd

-- | 2 + 2
four :: Term
four = add `App` num 2 `App` num 2

-- | Takes a set of numbers and adds 2 to each of its elements.
plus2all :: Term
plus2all = Lam "xs" $ BigJoin "x" (Var "xs") $ Set [(add `App` Var "x") `App` num 2]

-- | The even numbers, defined recursively with the z fixed point combinator
-- (there is no built-in fixpoint construct): evens = {0} ∨ plus2all(evens).
-- Because z only finds fixed points of functions, evens is wrapped in a
-- function whose argument (0xbeef / 0xdead) is ignored.
zevens :: Term
zevens = (z `App` Lam "evens" (Lam "_" body)) `App` num 0xbeef
  where body = Set [num 0] `Join` (plus2all `App` (Var "evens" `App` num 0xdead))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment