@sunny-g
Forked from VictorTaelin/dps_sup_nodes.md
Created August 7, 2024 15:53
Accelerating Discrete Program Search with SUP Nodes

Accelerating Discrete Program Search

I am investigating how to use Bend (a parallel language) to accelerate Symbolic AI; in particular, Discrete Program Search. Basically, think of it as an alternative to LLMs, GPTs and NNs that is also capable of generating code, but by entirely different means. This kind of approach has never been scaled with mass compute before (it wasn't possible!), but Bend changes this. So my idea was to do it and see where it goes.

Now, while I was implementing some candidate algorithms in Bend, I realized that, rather than mass parallelism, I could use an entirely different mechanism to speed things up: SUP Nodes. SUP Nodes are a feature that Bend inherited from its underlying model ("Interaction Combinators"). In simple terms, they allow us to combine multiple functions into a single superposed one, and apply them all to an argument "at the same time". In short, they let us call N functions at a fraction of the expected cost. Or, put differently: why parallelize when we can share?
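
Plain Haskell has no SUP nodes, so the sketch below is only a loose analogy for the "share instead of parallelize" idea, not a model of Bend/HVM semantics. It represents a set of candidate functions as a tree (`SupFn`, `applyShared`, `applySeparately` and `choices` are hypothetical names): each root-to-leaf path is one candidate, and every step above a branch point is computed once and reused by all candidates below it, instead of being recomputed per candidate.

```haskell
-- A loose analogy in plain Haskell, NOT a model of HVM's SUP nodes.
-- Each root-to-leaf path is one candidate function; a step placed above a
-- branch point is shared by every candidate below it.
data SupFn
  = Step (Int -> Int) SupFn -- one step, shared by every candidate below it
  | Split SupFn SupFn       -- the candidates diverge here
  | Done                    -- end of a candidate

-- Applies every candidate to x "at once": each shared step runs only once.
-- Returns the results and the number of steps executed.
applyShared :: SupFn -> Int -> ([Int], Int)
applyShared (Step f k) x =
  let (ys, n) = applyShared k (f x) in (ys, n + 1)
applyShared (Split a b) x =
  let (ys, n) = applyShared a x
      (zs, m) = applyShared b x
  in (ys ++ zs, n + m)
applyShared Done x = ([x], 0)

-- Expands the tree into one plain list of steps per candidate.
candidates :: SupFn -> [[Int -> Int]]
candidates (Step f k)  = map (f :) (candidates k)
candidates (Split a b) = candidates a ++ candidates b
candidates Done        = [[]]

-- Applies each candidate separately, from scratch (no sharing).
-- Returns the results and the number of steps executed.
applySeparately :: SupFn -> Int -> ([Int], Int)
applySeparately t x =
  let cs = candidates t
  in (map (foldl (\acc f -> f acc) x) cs, sum (map length cs))

-- Hypothetical search space: at each of d levels, choose (+ 1) or (* 2).
choices :: Int -> SupFn
choices 0 = Done
choices d = Split (Step (+ 1) (choices (d - 1))) (Step (* 2) (choices (d - 1)))
```

For `d = 10` levels there are 1024 candidates; `applyShared (choices 10) 1` executes 2046 steps, while `applySeparately (choices 10) 1` executes 10240 steps for the same results. In this toy, sharing only common prefixes saves roughly a factor of d/2; it illustrates the principle, not the size of the speedups reported here for HVM.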

As you can already imagine, this could be extremely important for Symbolic AI algorithms like Discrete Program Search (DPS), one of the top solutions to ARC-AGI (a memorization-resistant reasoning benchmark on which LLMs struggle). That's because DPS works by enumerating a massive number of candidate programs and trying each one on the test cases until one succeeds. Obviously, this approach is dumb, exponential and computationally prohibitive, which is why ARC-AGI was created to begin with: to motivate better (NN-based) algorithms.
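
The enumerate-and-test loop described above can be sketched generically in Haskell. This is a minimal sketch under assumed names (`searchFirst`, `prims`, `pipelines` and `space` are hypothetical and not part of the original program): it enumerates a toy space of integer pipelines, shortest first, and returns the index and name of the first candidate that passes every test case.

```haskell
-- Brute-force program search (hypothetical sketch): walk the candidates in
-- enumeration order and return the first one that maps every test input to
-- its expected output, together with its position in the enumeration.
searchFirst :: Eq o => [(i, o)] -> [(name, i -> o)] -> Maybe (Int, name)
searchFirst tests cands =
  case [ (n, nm) | (n, (nm, f)) <- zip [0 ..] cands
                 , all (\(x, y) -> f x == y) tests ] of
    (hit : _) -> Just hit
    []        -> Nothing

-- A toy program space: pipelines of primitive Int functions.
prims :: [(String, Int -> Int)]
prims = [("inc", (+ 1)), ("dbl", (* 2)), ("neg", negate)]

-- All pipelines of exactly k primitives, named in order and applied left to right.
pipelines :: Int -> [([String], Int -> Int)]
pipelines 0 = [([], id)]
pipelines k =
  [ (ns ++ [m], g . f) | (ns, f) <- pipelines (k - 1), (m, g) <- prims ]

-- Every pipeline of length 0..3, shortest first (1 + 3 + 9 + 27 = 40 candidates).
space :: [([String], Int -> Int)]
space = concatMap pipelines [0 .. 3]
```

For example, `searchFirst [(1, 3), (5, 11)] space` returns `Just (7, ["dbl","inc"])`, the pipeline computing 2x + 1, after rejecting the seven earlier candidates. The number of pipelines of length k grows as 3^k, which is the exponential blow-up mentioned above.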

But is brute force really that bad? While there is a massive number of terms to explore, these terms aren't random; they're structured, redundant and highly repetitive. This suggests that a mechanism to optimally share intermediate computations could greatly improve the search. And that's exactly what SUP Nodes do! This sounded so compelling that I spent the last few weeks trying to come up with the right way to do it. A few days ago, I had a huge breakthrough, but there was still one last issue: how to collapse the superpositions back into a single result without generating more work? Today, I realized how, and all the pieces fell into place.

As a result, I can now search programs that would take trillions of interactions in fewer than a million interactions. To validate it, I implemented the exact same algorithm (a very simple enumerator) in Haskell, as faithfully as I could. Using the Omega Monad, Haskell takes about 2.8 seconds to find a sample XOR-XNOR function, while HVM takes 0.0085 seconds (roughly 330x faster). So, as far as I can tell, this does seem like a real speedup. I'm publishing the Haskell code below as a sanity check. Am I doing something wrong? Without changing the algorithm (i.e., keeping it a brute-force search), can we bring this time down?
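
The Omega Monad matters because the enumeration is infinite along many branches at once: plain list concatenation would keep walking the first infinite branch forever and never reach the others. The standalone sketch below restates the `diagonal` helper from the Haskell program (same definition) and contrasts it with `concat`; `branchA`, `branchB`, `naive` and `fair` are hypothetical names standing in for two infinite branches of a search space.

```haskell
-- Fair interleaving of a (possibly infinite) list of (possibly infinite)
-- lists: the same definition as `diagonal` in the program below.
diagonal :: [[a]] -> [a]
diagonal xs = concat (stripe xs) where
  stripe []             = []
  stripe ([]     : xss) = stripe xss
  stripe ((x:xs) : xss) = [x] : zipCons xs (stripe xss)
  zipCons []     ys     = ys
  zipCons xs     []     = map (:[]) xs
  zipCons (x:xs) (y:ys) = (x:y) : zipCons xs ys

-- Two hypothetical infinite branches of a search space.
branchA, branchB :: [Integer]
branchA = [1 ..]
branchB = [100 ..]

-- `concat` never leaves branchA; `diagonal` alternates between branches.
naive, fair :: [Integer]
naive = take 6 (concat [branchA, branchB])
fair  = take 6 (diagonal [branchA, branchB])
```

Here `naive` is `[1,2,3,4,5,6]`, while `fair` is `[1,2,100,3,101,4]`: every branch is visited eventually, which is what lets a brute-force search reach programs that sit deep inside any branch.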

Haskell Algorithm (slow, via Omega Monad)

-- A demo, minimal Program Search in Haskell

-- Given a test (a set of input/output pairs), it will find a function that
-- passes it. This file is for demo purposes, so it is restricted to simple,
-- single-pass recursive functions. The idea is to use HVM superpositions to
-- try many functions "at once". Haskell does not have them, so we just use
-- the Omega Monad to convert the superposed term into a list of functions,
-- and try each one separately.

import Control.Monad (forM_)

-- PRELUDE
----------

newtype Omega a = Omega { runOmega :: [a] }

instance Functor Omega where
  fmap f (Omega xs) = Omega (map f xs)

instance Applicative Omega where
  pure x                = Omega [x]
  Omega fs <*> Omega xs = Omega [f x | f <- fs, x <- xs]

instance Monad Omega where
  Omega xs >>= f = Omega $ diagonal $ map (\x -> runOmega (f x)) xs

diagonal :: [[a]] -> [a]
diagonal xs = concat (stripe xs) where
  stripe []             = []
  stripe ([]     : xss) = stripe xss
  stripe ((x:xs) : xss) = [x] : zipCons xs (stripe xss)
  zipCons []     ys     = ys
  zipCons xs     []     = map (:[]) xs
  zipCons (x:xs) (y:ys) = (x:y) : zipCons xs ys

-- ENUMERATOR
-------------

-- A bit-string
data Bin
  = O Bin
  | I Bin
  | E

-- A simple DSL for `Bin -> Bin` terms
data Term
  = MkO Term      -- emits the bit 0
  | MkI Term      -- emits the bit 1
  | Mat Term Term -- pattern-matches on the argument
  | Rec           -- recurses on the argument
  | Ret           -- returns the argument
  | Sup Term Term -- a superposition of two functions

-- Checks if two Bins are equal
bin_eq :: Bin -> Bin -> Bool
bin_eq (O xs) (O ys) = bin_eq xs ys
bin_eq (I xs) (I ys) = bin_eq xs ys
bin_eq E      E      = True
bin_eq _      _      = False

-- Stringifies a Bin
bin_show :: Bin -> String
bin_show (O xs) = "O" ++ bin_show xs
bin_show (I xs) = "I" ++ bin_show xs
bin_show E      = "E"

-- Checks if two terms are equal
term_eq :: Term -> Term -> Bool
term_eq (Mat l0 r0) (Mat l1 r1) = term_eq l0 l1 && term_eq r0 r1
term_eq (Sup a0 b0) (Sup a1 b1) = term_eq a0 a1 && term_eq b0 b1
term_eq (MkO t0)    (MkO t1)    = term_eq t0 t1
term_eq (MkI t0)    (MkI t1)    = term_eq t0 t1
term_eq Rec         Rec         = True
term_eq Ret         Ret         = True
term_eq _           _           = False

-- Stringifies a term
term_show :: Term -> String
term_show (MkO t)    = "(O " ++ term_show t ++ ")"
term_show (MkI t)    = "(I " ++ term_show t ++ ")"
term_show (Mat l r)  = "{O:" ++ term_show l ++ "|I:" ++ term_show r ++ "}"
term_show (Sup a b)  = "{" ++ term_show a ++ "|" ++ term_show b ++ "}"
term_show Rec        = "@"
term_show Ret        = "*"

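-- Enumerates all terms as a single, infinite superposed Term.
-- `s` is True under a `Mat`, the only place where `Rec` is allowed.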
enum :: Bool -> Term
enum s = (if s then Sup Rec else id) $ Sup Ret $ Sup (intr s) (elim s) where
  intr s = Sup (MkO (enum s)) (MkI (enum s))
  elim s = Mat (enum True) (enum True)

-- Converts a collapsed (Sup-free) Term into a native function
make :: Term -> (Bin -> Bin) -> Bin -> Bin
make Ret       _ x = x
make Rec       f x = f x
make (MkO trm) f x = O (make trm f x)
make (MkI trm) f x = I (make trm f x)
make (Mat l r) f x = case x of
  O xs -> make l f xs
  I xs -> make r f xs
  E    -> E
make (Sup _ _) _ _ = error "make: collapse superpositions before calling make"

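-- Finds the first program that satisfies a test
-- (the [] case only keeps this total; `enum` produces an infinite list)
search :: Int -> (Term -> Bool) -> [Term] -> IO ()
search n _    []       = putStrLn $ "NOT FOUND (after " ++ show n ++ " guesses)"
search n test (tm:tms) = do
  if test tm then
    putStrLn $ "FOUND " ++ term_show tm ++ " (after " ++ show n ++ " guesses)"
  else
    search (n+1) test tms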

-- Collapses a superposed term to a list of terms, diagonalizing
collapse :: Term -> Omega Term
collapse (MkO t) = do
  t' <- collapse t
  return $ MkO t'
collapse (MkI t) = do
  t' <- collapse t
  return $ MkI t'
collapse (Mat l r) = do
  l' <- collapse l
  r' <- collapse r
  return $ Mat l' r'
collapse (Sup a b) =
  let a' = runOmega (collapse a) in
  let b' = runOmega (collapse b) in
  Omega (diagonal [a',b'])
collapse Rec = return Rec
collapse Ret = return Ret

-- Some test cases:
-- ----------------

test_not :: Term -> Bool
test_not tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (O (I (I (I (O (I (I E))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (O (O (O (I (I (O (O (O E))))))))
  e1 = (bin_eq (fn x1) y1)

test_inc :: Term -> Bool
test_inc tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (I (O (O (O (I (O (O E))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (O (O (O (I (O (I (I (I E))))))))
  e1 = (bin_eq (fn x1) y1)

test_mix :: Term -> Bool
test_mix tm = e0 && e1 where
  fn = make tm fn
  x0 = (O (I (O (O (O (I (O (O E))))))))
  y0 = (I (O (I (I (I (O (I (O (I (O (I (I (I (O (I (O E))))))))))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (I (I (O (O (I (I (I E))))))))
  y1 = (I (I (I (I (I (I (I (O (I (O (I (I (I (I (I (I E))))))))))))))))
  e1 = (bin_eq (fn x1) y1)

test_xors :: Term -> Bool
test_xors tm = e0 && e1 where
  fn = make tm fn
  x0 = (I (I (O (O (O (I (O (O E))))))))
  y0 = (I (I (O (I E))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (I (O (O (I (I (I (O (I E))))))))
  y1 = (O (O (I (O E))))
  e1 = (bin_eq (fn x1) y1)

test_xor_xnor :: Term -> Bool
test_xor_xnor tm = e0 && e1 where
  fn = make tm fn
  x0 = (I (O (O (I (I (O (I (I (I (O E))))))))))
  y0 = (I (O (I (O (I (O (O (I (I (O E))))))))))
  e0 = (bin_eq (fn x0) y0)
  x1 = (O (I (O (O (O (I (O (I (O (O E))))))))))
  y1 = (I (O (O (I (I (O (I (O (O (I E)))))))))) 
  e1 = (bin_eq (fn x1) y1)

main :: IO ()
main = search 0 test_xor_xnor $ runOmega $ collapse $ enum False
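
As a hand check of the DSL, the trimmed standalone copy below keeps `Bin`, `Term` (without `Sup`) and `make` from the program above, and adds hypothetical helpers `run`, `fromStr` and `toStr`. It defines two hand-written terms, `notTerm` and `incTerm` (also hypothetical, written for illustration rather than taken from `search` output); bits are read least-significant first, as `test_inc` implies.

```haskell
-- Trimmed copy of the DSL above, without `Sup` (hand check only).
data Bin = O Bin | I Bin | E
  deriving (Eq, Show)

data Term
  = MkO Term      -- emits the bit 0
  | MkI Term      -- emits the bit 1
  | Mat Term Term -- pattern-matches on the argument
  | Rec           -- recurses on the argument
  | Ret           -- returns the argument

make :: Term -> (Bin -> Bin) -> Bin -> Bin
make Ret       _ x = x
make Rec       f x = f x
make (MkO trm) f x = O (make trm f x)
make (MkI trm) f x = I (make trm f x)
make (Mat l r) f x = case x of
  O xs -> make l f xs
  I xs -> make r f xs
  E    -> E

-- Hypothetical helper: ties the knot so that `Rec` calls the function itself.
run :: Term -> Bin -> Bin
run tm = fn where fn = make tm fn

-- Hypothetical helpers: "OIO" <-> O (I (O E)); any char other than 'O' is I.
fromStr :: String -> Bin
fromStr = foldr (\c b -> if c == 'O' then O b else I b) E

toStr :: Bin -> String
toStr (O b) = 'O' : toStr b
toStr (I b) = 'I' : toStr b
toStr E     = ""

-- Hypothetical hand-written term: flips every bit.
notTerm :: Term
notTerm = Mat (MkI Rec) (MkO Rec)

-- Hypothetical hand-written term: adds 1 (O -> I and stop; I -> O and carry).
incTerm :: Term
incTerm = Mat (MkI Ret) (MkO Rec)
```

For instance, `toStr (run incTerm (fromStr "IIIOOIII"))` evaluates to `"OOOIOIII"`, matching the second `test_inc` case, and `notTerm` reproduces both `test_not` cases. `Rec` only appears under `Mat`, so each recursive call receives a strictly shorter argument.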

HVM Algorithm (fast, with SUP Nodes)
