I am investigating how to use Bend (a parallel language) to accelerate Symbolic AI; in particular, Discrete Program Search. Basically, think of it as an alternative to LLMs, GPTs and NNs that is also capable of generating code, but by entirely different means. This kind of approach was never scaled with mass compute before - it wasn't possible! - but Bend changes this. So, my idea was to do it, and see where it goes.
Now, while I was implementing some candidate algorithms in Bend, I realized that, rather than mass parallelism, I could use an entirely different mechanism to speed things up: SUP Nodes. Basically, this is a feature that Bend inherited from its underlying model ("Interaction Combinators") that, in simple terms, allows us to combine multiple functions into a single superposed one, and apply them all to an argument "at the same time". In short, it allows us to call N functions at a fraction of the expected cost. Put differently: why parallelize when we can share?
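To make the idea concrete, here is a minimal Haskell sketch of what applying a superposed function means. The `Sup` type, `applySup`, `alternatives` and `superposed` are hypothetical names chosen for illustration, not part of Bend or HVM. The sketch models only the result (one output per superposed function), not the sharing that makes this cheap on HVM.

```haskell
-- A value that is either a single thing or a superposition of two alternatives.
-- (Hypothetical model; HVM represents this natively with SUP nodes.)
data Sup a = One a | Sup (Sup a) (Sup a)

-- Applying a superposed function to an argument distributes the application
-- over both branches: {f | g} x = {f x | g x}.
applySup :: Sup (a -> b) -> a -> Sup b
applySup (One f)   x = One (f x)
applySup (Sup f g) x = Sup (applySup f x) (applySup g x)

-- Lists every alternative, left to right.
alternatives :: Sup a -> [a]
alternatives (One x)   = [x]
alternatives (Sup a b) = alternatives a ++ alternatives b

-- Three functions combined into a single superposed one.
superposed :: Sup (Int -> Int)
superposed = Sup (One (+ 1)) (Sup (One (* 2)) (One negate))
```

With these definitions, `alternatives (applySup superposed 10)` evaluates to `[11, 20, -10]`: one call site, three results.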
As you can already imagine, this could be extremely important for Symbolic AI algorithms like Discrete Program Search (DPS), one of the top solutions to ARC-AGI (a memorization-resistant reasoning benchmark on which LLMs struggle). That's because DPS works by enumerating a massive number of candidate programs and trying each one on the test cases, until one succeeds. Obviously, this approach is dumb, exponential and computationally prohibitive, which is why ARC-AGI was created to begin with: to motivate better (NN-based) algorithms.
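As a minimal sketch of that enumerate-and-test loop, assuming a toy arithmetic DSL: the `Expr` type, `exprsOfSize` and `searchExpr` are illustrative names, unrelated to the enumerator used later in this post.

```haskell
-- A toy DSL for `Int -> Int` programs (illustrative only).
data Expr = X | One | Add Expr Expr | Mul Expr Expr
  deriving (Eq, Show)

-- Runs a candidate program on an input.
eval :: Expr -> Int -> Int
eval X         x = x
eval One       _ = 1
eval (Add a b) x = eval a x + eval b x
eval (Mul a b) x = eval a x * eval b x

-- All expressions with exactly n nodes.
exprsOfSize :: Int -> [Expr]
exprsOfSize n
  | n <= 0    = []
  | n == 1    = [X, One]
  | otherwise =
      [ op a b
      | i  <- [1 .. n - 2]
      , a  <- exprsOfSize i
      , b  <- exprsOfSize (n - 1 - i)
      , op <- [Add, Mul]
      ]

-- Brute force: enumerate candidates by increasing size and return the first
-- one that reproduces every (input, output) pair. Loops forever if the DSL
-- cannot express a matching program.
searchExpr :: [(Int, Int)] -> Expr
searchExpr tests =
  head [ e | n <- [1 ..], e <- exprsOfSize n, all (\(i, o) -> eval e i == o) tests ]
```

For example, `searchExpr [(2, 5), (3, 10)]` returns `Add One (Mul X X)`, i.e. `1 + x * x`. The number of candidates grows exponentially with size, which is the cost the rest of this post tries to attack.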
But is brute-force really that bad? While there is a massive number of terms to explore, these terms aren't random; they're structured, redundant and highly repetitive. This suggests that some mechanism to optimally share intermediate computations could greatly improve the search. And that's exactly what SUP Nodes do! This sounded so compelling that I spent the last few weeks trying to come up with the right way to do it. A few days ago, I had a huge breakthrough, but there was still one last issue: how to collapse the superpositions back into a single result, without generating more work? Today, I just realized how, and all the pieces fell into place.
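A rough way to see that redundancy, as a sketch: compare the number of nodes in one superposed term with the total number of nodes across all the plain terms it stands for. The `T` type, `sharedSize`, `expanded` and `grid` are hypothetical, and node count is only a crude proxy for HVM interactions, not a measurement.

```haskell
-- A tiny superposed term language (hypothetical, for illustration).
data T = Leaf | O T | I T | Pair T T | Sup T T

-- Nodes stored when alternatives stay superposed (SUP nodes not counted).
sharedSize :: T -> Integer
sharedSize Leaf       = 1
sharedSize (O t)      = 1 + sharedSize t
sharedSize (I t)      = 1 + sharedSize t
sharedSize (Pair a b) = 1 + sharedSize a + sharedSize b
sharedSize (Sup a b)  = sharedSize a + sharedSize b

-- (number of plain terms, total nodes across all of them) once every
-- superposition is expanded into separate candidates.
expanded :: T -> (Integer, Integer)
expanded Leaf       = (1, 1)
expanded (O t)      = let (n, s) = expanded t in (n, s + n)
expanded (I t)      = let (n, s) = expanded t in (n, s + n)
expanded (Pair a b) =
  let (na, sa) = expanded a
      (nb, sb) = expanded b
  in (na * nb, na * nb + sa * nb + sb * na)
expanded (Sup a b)  =
  let (na, sa) = expanded a
      (nb, sb) = expanded b
  in (na + nb, sa + sb)

-- A binary tree of pairs whose leaves are each a superposition of two choices.
grid :: Int -> T
grid 0 = Sup (O Leaf) (I Leaf)
grid k = Pair (grid (k - 1)) (grid (k - 1))
```

Here `sharedSize (grid 3)` is `39`, while `expanded (grid 3)` is `(256, 5888)`: the superposed form grows additively with each `Pair`, the expanded form multiplicatively.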
As a result, I can now search programs that would take trillions of interactions in less than a million interactions. To validate it, I implemented the exact same algorithm (a very simple enumerator) in Haskell, as well as I could. Using the Omega Monad, Haskell takes about 2.8 seconds to find a sample XOR-XNOR function, while HVM takes 0.0085s. So, as far as I can tell, this does seem like a speedup. I'm publishing below the Haskell code as a sanity check. Am I doing something wrong? Without changing the algorithm (i.e., keeping it a brute-force search), can we bring this time down?
-- A demo, minimal Program Search in Haskell
-- Given test (input/output) pairs, it will find a function that passes them.
-- This file is for demo purposes, so, it is restricted to just simple, single
-- pass recursive functions. The idea is to use HVM superpositions to try many
-- functions "at once". Obviously, Haskell does not have them, so, we just use
-- the Omega Monad to convert to a list of functions, and try each separately.
import Control.Monad (forM_)
-- PRELUDE
----------
newtype Omega a = Omega { runOmega :: [a] }
instance Functor Omega where
  fmap f (Omega xs) = Omega (map f xs)
instance Applicative Omega where
  pure x = Omega [x]
  Omega fs <*> Omega xs = Omega [f x | f <- fs, x <- xs]
instance Monad Omega where
  Omega xs >>= f = Omega $ diagonal $ map (\x -> runOmega (f x)) xs
diagonal :: [[a]] -> [a]
diagonal xs = concat (stripe xs) where
  stripe [] = []
  stripe ([] : xss) = stripe xss
  stripe ((x:xs) : xss) = [x] : zipCons xs (stripe xss)
  zipCons [] ys = ys
  zipCons xs [] = map (:[]) xs
  zipCons (x:xs) (y:ys) = (x:y) : zipCons xs ys
-- ENUMERATOR
-------------
-- A bit-string
data Bin
  = O Bin
  | I Bin
  | E
-- A simple DSL for `Bin -> Bin` terms
data Term
  = MkO Term      -- emits the bit 0
  | MkI Term      -- emits the bit 1
  | Mat Term Term -- pattern-matches on the argument
  | Rec           -- recurses on the argument
  | Ret           -- returns the argument
  | Sup Term Term -- a superposition of two functions
-- Checks if two Bins are equal
bin_eq :: Bin -> Bin -> Bool
bin_eq (O xs) (O ys) = bin_eq xs ys
bin_eq (I xs) (I ys) = bin_eq xs ys
bin_eq E E = True
bin_eq _ _ = False
-- Stringifies a Bin
bin_show :: Bin -> String
bin_show (O xs) = "O" ++ bin_show xs
bin_show (I xs) = "I" ++ bin_show xs
bin_show E = "E"
-- Checks if two terms are equal
term_eq :: Term -> Term -> Bool
term_eq (Mat l0 r0) (Mat l1 r1) = term_eq l0 l1 && term_eq r0 r1
term_eq (MkO t0) (MkO t1) = term_eq t0 t1
term_eq (MkI t0) (MkI t1) = term_eq t0 t1
term_eq Rec Rec = True
term_eq Ret Ret = True
term_eq _ _ = False
-- Stringifies a term
term_show :: Term -> String
term_show (MkO t) = "(O " ++ term_show t ++ ")"
term_show (MkI t) = "(I " ++ term_show t ++ ")"
term_show (Mat l r) = "{O:" ++ term_show l ++ "|I:" ++ term_show r ++ "}"
term_show (Sup a b) = "{" ++ term_show a ++ "|" ++ term_show b ++ "}"
term_show Rec = "@"
term_show Ret = "*"
-- Enumerates all terms
enum :: Bool -> Term
enum s = (if s then Sup Rec else id) $ Sup Ret $ Sup (intr s) (elim s) where
  intr s = Sup (MkO (enum s)) (MkI (enum s))
  elim s = Mat (enum True) (enum True)
-- Converts a Term into a native function
make :: Term -> (Bin -> Bin) -> Bin -> Bin
make Ret _ x = x
make Rec f x = f x
make (MkO trm) f x = O (make trm f x)
make (MkI trm) f x = I (make trm f x)
make (Mat l r) f x = case x of
  O xs -> make l f xs
  I xs -> make r f xs
  E    -> E
-- Finds a program that satisfies a test
search :: Int -> (Term -> Bool) -> [Term] -> IO ()
search n test (tm:tms) = do
  if test tm then
    putStrLn $ "FOUND " ++ term_show tm ++ " (after " ++ show n ++ " guesses)"
  else
    search (n+1) test tms
-- Collapses a superposed term to a list of terms, diagonalizing
collapse :: Term -> Omega Term
collapse (MkO t) = do
  t' <- collapse t
  return $ MkO t'
collapse (MkI t) = do
  t' <- collapse t
  return $ MkI t'
collapse (Mat l r) = do
  l' <- collapse l
  r' <- collapse r
  return $ Mat l' r'
collapse (Sup a b) =
  let a' = runOmega (collapse a) in
  let b' = runOmega (collapse b) in
  Omega (diagonal [a',b'])
collapse Rec = return Rec
collapse Ret = return Ret
-- Some test cases:
-- ----------------
test_not :: Term -> Bool
test_not tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (O (I (I (I (O (I (I E))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (O (O (O (I (I (O (O (O E))))))))
  e1 = (bin_eq (fn x1) y1)
test_inc :: Term -> Bool
test_inc tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (I (O (O (O (I (O (O E))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (O (O (O (I (O (I (I (I E))))))))
  e1 = (bin_eq (fn x1) y1)
test_mix :: Term -> Bool
test_mix tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (O (I (I (I (O (I (O (I (O (I (I (I (O (I (O E))))))))))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (I (I (I (I (I (I (I (O (I (O (I (I (I (I (I (I E))))))))))))))))
  e1 = (bin_eq (fn x1) y1)
test_xors :: Term -> Bool
test_xors tm = e0 && e1 where
  fn = make tm fn
  x0 = (I (I (O (O (O (I (O (O E))))))))
  y0 = (I (I (O (I E))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (O (O (I (I (I (O (I E))))))))
  y1 = (O (O (I (O E))))
  e1 = (bin_eq (fn x1) y1)
test_xor_xnor :: Term -> Bool
test_xor_xnor tm = e0 && e1 where
  fn = make tm fn
  x0 = (I (O (O (I (I (O (I (I (I (O E))))))))))
  y0 = (I (O (I (O (I (O (O (I (I (O E))))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (O (I (O (O (O (I (O (I (O (O E))))))))))
  y1 = (I (O (O (I (I (O (I (O (O (I E))))))))))
  e1 = (bin_eq (fn x1) y1)
main :: IO ()
main = search 0 test_xor_xnor $ runOmega $ collapse $ enum False
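To see why `diagonal` matters here: it interleaves the rows fairly, so every candidate is reached after finitely many steps even when each row is infinite, which a plain `concat` would not guarantee. The block below copies `diagonal` from the program above so that it is self-contained, and adds two small hypothetical examples (`finiteDemo` and `pairsDemo` are names introduced only for this sketch).

```haskell
-- Copied from the program above so that this block stands alone.
diagonal :: [[a]] -> [a]
diagonal xs = concat (stripe xs) where
  stripe []             = []
  stripe ([] : xss)     = stripe xss
  stripe ((x:xs) : xss) = [x] : zipCons xs (stripe xss)
  zipCons [] ys         = ys
  zipCons xs []         = map (:[]) xs
  zipCons (x:xs) (y:ys) = (x:y) : zipCons xs ys

-- Finite rows: elements come out one anti-diagonal at a time.
finiteDemo :: [Int]
finiteDemo = diagonal [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

-- Infinitely many infinite rows: every pair still shows up eventually.
pairsDemo :: [(Int, Int)]
pairsDemo = take 6 (diagonal [[(i, j) | j <- [0 ..]] | i <- [0 ..]])
```

`finiteDemo` is `[1,2,4,3,5,7,6,8,9]` and `pairsDemo` is `[(0,0),(0,1),(1,0),(0,2),(1,1),(2,0)]`. This is the order in which `collapse` hands candidate terms to `search`, so both branches of a `Sup` and both sides of a `Mat` get explored without one starving the other.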
NOTE: to run this, you must use the dup_labels branch of HVM1. We'll soon adapt Bend to work for this too, but it is currently lacking lazy mode, 64-bit dup labels and SUP-node quoting.
// HVM Prelude
// -----------
(Fix f) = (f (Fix f))
(U60.if 0 t f) = f
(U60.if x t f) = t
(U60.show n) = (U60.show.go n "")
(U60.show.go n x) = (λx(U60.if (> n 9) (U60.show.go (/ n 10) x) x) (String.cons (+ 48 (% n 10)) x))
(U60.seq 0 cont) = (cont 0)
(U60.seq n cont) = (cont n)
(And 0 x) = 0
(And 1 x) = x
(Or 0 x) = x
(Or 1 x) = 1
(List.get (List.nil) _) = (Err "out-of-bounds")
(List.get (List.cons x xs) 0) = x
(List.get (List.cons x xs) n) = (List.get xs (- n 1))
(List.map f List.nil) = List.nil
(List.map f (List.cons x xs)) = (List.cons (f x) (List.map f xs))
(List.imap f List.nil) = List.nil
(List.imap f (List.cons x xs)) = (List.cons (f 0 x) (List.imap λiλx(f (+ i 1) x) xs))
(List.concat (List.nil) ys) = ys
(List.concat (List.cons x xs) ys) = (List.cons x (List.concat xs ys))
(List.flatten List.nil) = List.nil
(List.flatten (List.cons x xs)) = (List.concat x (List.flatten xs))
(List.length List.nil) = 0
(List.length (List.cons x xs)) = (+ 1 (List.length xs))
(List.take 0 xs) = List.nil
(List.take n List.nil) = List.nil
(List.take n (List.cons x xs)) = (List.cons x (List.take (- n 1) xs))
(List.head (List.cons x xs)) = x
(List.tail (List.cons x xs)) = xs
(List.push x List.nil) = (List.cons x List.nil)
(List.push x (List.cons y ys)) = (List.cons y (List.push x ys))
(List.diagonal xs) = (List.flatten (List.stripe xs))
(List.stripe List.nil) = []
(List.stripe (List.cons List.nil xss)) = (List.stripe xss)
(List.stripe (List.cons (List.cons x xs) xss)) = (List.cons [x] (List.zip_cons xs (List.stripe xss)))
(List.zip_cons [] ys) = ys
(List.zip_cons xs []) = (List.map λk(List.cons k []) xs)
(List.zip_cons (List.cons x xs) (List.cons y ys)) = (List.cons (List.cons x y) (List.zip_cons xs ys))
(Omega.pure x) = [x]
(Omega.bind xs f) = (List.diagonal (List.map f xs))
(String.concat String.nil ys) = ys
(String.concat (String.cons x xs) ys) = (String.cons x (String.concat xs ys))
(String.join List.nil) = String.nil
(String.join (List.cons x xs)) = (String.concat x (String.join xs))
(String.seq (String.cons x xs) cont) = (U60.seq x λx(String.seq xs λxs(cont (String.cons x xs))))
(String.seq String.nil cont) = (cont String.nil)
(String.eq String.nil String.nil) = 1
(String.eq (String.cons x xs) (String.cons y ys)) = (And (== x y) (String.eq xs ys))
(String.eq xs ys) = 0
(String.take 0 xs) = String.nil
(String.take n String.nil) = String.nil
(String.take n (String.cons x xs)) = (String.cons x (String.take (- n 1) xs))
(Tup2.match (Tup2.new fst snd) fn) = (fn fst snd)
(Tup2.fst (Tup2.new fst snd)) = fst
(Tup2.snd (Tup2.new fst snd)) = snd
(Join None b) = b
(Join (Some x) b) = (Some x)
(Print [] value) = value
(Print msg value) = (String.seq (String.join msg) λstr(HVM.log str value))
// Priority Queue
// data PQ = Empty | Node U60 U60 PQ PQ
// PQ.new: Create a new empty Priority Queue
(PQ.new) = Empty
// PQ.put: Add a new (key, val) pair to the Priority Queue
(PQ.put key val Empty) = (Node key val Empty Empty)
(PQ.put key val (Node k v lft rgt)) = (PQ.put.aux (< key k) key val k v lft rgt)
(PQ.put.aux 1 key val k v lft rgt) = (Node key val (Node k v lft rgt) Empty)
(PQ.put.aux 0 key val k v lft rgt) = (Node k v (PQ.put key val lft) rgt)
// PQ.get: Get the smallest element and return it with the updated queue
(PQ.get Empty) = (HVM.LOG ERR 0)
(PQ.get (Node k v lft rgt)) = λs(s k v (PQ.merge lft rgt))
// Helper function to merge two priority queues
(PQ.merge Empty rgt) = rgt
(PQ.merge lft Empty) = lft
(PQ.merge (Node k1 v1 l1 r1) (Node k2 v2 l2 r2)) = (PQ.merge.aux (< k1 k2) k1 v1 l1 r1 k2 v2 l2 r2)
(PQ.merge.aux 1 k1 v1 l1 r1 k2 v2 l2 r2) = (Node k1 v1 (PQ.merge r1 (Node k2 v2 l2 r2)) l1)
(PQ.merge.aux 0 k1 v1 l1 r1 k2 v2 l2 r2) = (Node k2 v2 (PQ.merge (Node k1 v1 l1 r1) r2) l2)
// Collapser
(Collapse (HVM.SUP k a b) pq) = (Collapse None (PQ.put k a (PQ.put k b pq)))
(Collapse (Some x) pq) = x
(Collapse None pq) = ((PQ.get pq) λkλxλpq(Collapse x pq))
// Bin Enumerator
// --------------
(O xs) = λo λi λe (o xs)
(I xs) = λo λi λe (i xs)
E = λo λi λe e
(Bin.eq xs ys) = (xs
λxsp λys (ys λysp(Bin.eq xsp ysp) λysp(0) 0)
λxsp λys (ys λysp(0) λysp(Bin.eq xsp ysp) 0)
λys (ys λysp(0) λysp(0) 1)
ys)
(Term.eq (Mat l0 r0) (Mat l1 r1)) = (And (Term.eq l0 l1) (Term.eq r0 r1))
(Term.eq (MkO t0) (MkO t1)) = (Term.eq t0 t1)
(Term.eq (MkI t0) (MkI t1)) = (Term.eq t0 t1)
(Term.eq Rec Rec) = 1
(Term.eq Ret Ret) = 1
(Term.eq _ _) = 0
Zero = (O Zero)
Neg1 = (I Neg1)
(L0 x) = (+ (* x 2) 0)
(L1 x) = (+ (* x 2) 1)
(ENUM lab s) =
let lA = (+ (* lab 4) 0)
let lB = (+ (* lab 4) 1)
let lC = (+ (* lab 4) 2)
let rc = (U60.if s λx(HVM.SUP lB Rec x) λx(x))
let rt = λx(HVM.SUP lC Ret x)
(rt (rc (HVM.SUP lA
(INTR (L0 lab) s)
(ELIM (L1 lab) s))))
(INTR lab s) =
let lA = (+ (* lab 4) 3)
(HVM.SUP lA
(MkO (ENUM (L0 lab) s))
(MkI (ENUM (L1 lab) s)))
(ELIM lab s) =
(Mat (ENUM (L0 lab) 1)
(ENUM (L1 lab) 1))
(Make Ret ) = λfλx(x)
(Make Rec ) = λfλx(f x)
(Make (MkO trm)) = λfλx(O ((Make trm) f x))
(Make (MkI trm)) = λfλx(I ((Make trm) f x))
(Make (Mat l r)) = λfλx(x λx((Make l) f x) λx((Make r) f x) (E))
(Bin.show xs) = (xs λxs(String.join ["O" (Bin.show xs)]) λxs(String.join ["I" (Bin.show xs)]) "E")
(Bin.view xs) = (xs λxs(B0 (Bin.view xs)) λxs(B1 (Bin.view xs)) BE)
(COL (HVM.SUP k a b)) = (Join (COL a) (COL b))
(COL x) = x
(Flat (HVM.SUP k a b)) = (List.diagonal [(Flat a) (Flat b)])
(Flat Ret) = (Omega.pure Ret)
(Flat Rec) = (Omega.pure Rec)
(Flat (MkO trm)) = (Omega.bind (Flat trm) λtrm(Omega.pure (MkO trm)))
(Flat (MkI trm)) = (Omega.bind (Flat trm) λtrm(Omega.pure (MkI trm)))
(Flat (Mat l r)) = (Omega.bind (Flat l) λl(Omega.bind (Flat r)λr(Omega.pure (Mat l r))))
// Term.show: stringifies a term
(Term.show Ret) = "Ret"
(Term.show Rec) = "Rec"
(Term.show (MkO term)) = (String.join ["(O " (Term.show term) ")"])
(Term.show (MkI term)) = (String.join ["(I " (Term.show term) ")"])
(Term.show (Mat left right)) = (String.join ["{" (Term.show left) "|" (Term.show right) "}"])
(Test_same g cd fn) =
let x0 = (I (O (O (I (I (O (I (I (I (O E))))))))))
let y0 = (g x0)
let e0 = (Bin.eq (fn x0) y0)
let x1 = (O (I (O (O (O (I (O (I (O (O E))))))))))
let y1 = (g x1)
let e1 = (Bin.eq (fn x1) y1)
(U60.if (And e0 e1) (Some cd) None)
(Test_not cd fn) =
let x0 = (O (I (O (O (O (I (O (O E))))))))
let y0 = (I (O (I (I (I (O (I (I E))))))))
let e0 = (Bin.eq (fn x0) y0)
let x1 = (I (I (I (O (O (I (I (I E))))))))
let y1 = (O (O (O (I (I (O (O (O E))))))))
let e1 = (Bin.eq (fn x1) y1)
(U60.if (And e0 e1) (Some cd) None)
(Test_inc cd fn) =
let x0 = (O (I (O (O (O (I (O (O E))))))))
let y0 = (I (I (O (O (O (I (O (O E))))))))
let e0 = (Bin.eq (fn x0) y0)
let x1 = (I (I (I (O (O (I (I (I E))))))))
let y1 = (O (O (O (I (O (I (I (I E))))))))
let e1 = (Bin.eq (fn x1) y1)
(U60.if (And e0 e1) (Some cd) None)
(Test_mix cd fn) =
let x0 = (O (I (O (O (O (I (O (O E))))))))
let y0 = (I (O (I (I (I (O (I (O (I (O (I (I (I (O (I (O E))))))))))))))))
let e0 = (Bin.eq (fn x0) y0)
let x1 = (I (I (I (O (O (I (I (I E))))))))
let y1 = (I (I (I (I (I (I (I (O (I (O (I (I (I (I (I (I E))))))))))))))))
let e1 = (Bin.eq (fn x1) y1)
(U60.if (And e0 e1) (Some cd) None)
(Test_xors cd fn) =
let x0 = (I (I (O (O (O (I (O (O E))))))))
let y0 = (I (I (O (I E))))
let e0 = (Bin.eq (fn x0) y0)
let x1 = (I (O (O (I (I (I (O (I E))))))))
let y1 = (O (O (I (O E))))
let e1 = (Bin.eq (fn x1) y1)
(U60.if (And e0 e1) (Some cd) None)
(Test_xor_xnor cd fn) =
let x0 = (I (O (O (I (I (O (I (I (I (O E))))))))))
let y0 = (I (O (I (O (I (O (O (I (I (O E))))))))))
let e0 = (Bin.eq (fn x0) y0)
let x1 = (O (I (O (O (O (I (O (I (O (O E))))))))))
let y1 = (I (O (O (I (I (O (I (O (O (I E))))))))))
let e1 = (Bin.eq (fn x1) y1)
(U60.if (And e0 e1) (Some cd) None)
Main =
//let term = (Mat (Mat (MkO (MkI (Rec))) (MkI (MkO (Rec)))) (Mat (MkI (MkO (Rec))) (MkO (MkI (Rec)))))
//let bits = (O (I (O (O (O (I (O (I (O (O E))))))))))
//let func = (Fix (Make term))
//(Bin.show (func bits))
//(Test_xor_xnor term (Fix (Make term)))
//let term = (Mat (Mat (MkI (MkO (Rec))) (MkO (MkI (Rec)))) (Mat (MkO (MkI (Rec))) (MkI (MkO (Rec)))))
//let func = (Fix (Make term))
let terms = (ENUM 1 0)
let funcs = (Fix (Make terms))
let found = (Test_xor_xnor terms funcs)
(Collapse found PQ.new)
Edit: HVM source published. Announcement on https://x.com/VictorTaelin/status/1829143659440144493
Is this a git commit somewhere? I'm not sure how to read this