Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save VictorTaelin/6c7629259508118d385114b873a07358 to your computer and use it in GitHub Desktop.

Select an option

Save VictorTaelin/6c7629259508118d385114b873a07358 to your computer and use it in GitHub Desktop.
challenge: optimize a reference interaction calculus implementation
{-# LANGUAGE MultilineStrings #-}
import Data.IORef
import Debug.Trace
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import qualified Data.Map as M
-- | Label carried by superpositions ('Sup') and duplications ('Dup').
type Lab = String
-- | Textual identifier used by variables, binders and stuck names.
type Name = String
-- | String-keyed map; the substitution tables inside 'Env' use it.
type Map a = M.Map String a
-- Surface grammar of terms:
--   dp0 ::= x₀
--   dp1 ::= x₁
--   era ::= &{}
--   sup ::= &L{a,b}
--   dup ::= !x&L=v;t
--   var ::= x
--   lam ::= λx.f
--   app ::= (f x)
--   nam ::= name
--   dry ::= (name arg)

-- | Syntax tree shared by the parser, the evaluator and the printer.
data Term
  = Nam Name                -- nam: a stuck name
  | Var Name                -- var: x
  | Dp0 Name                -- dp0: x₀
  | Dp1 Name                -- dp1: x₁
  | Era                     -- era: &{}
  | Sup Lab Term Term       -- sup: &L{a,b}
  | Dup Name Lab Term Term  -- dup: !x&L=v;t
  | Lam Name Term           -- lam: λx.f
  | Abs Name Term           -- printed exactly like Lam by the Show instance
  | Dry Term Term           -- dry: stuck application
  | App Term Term           -- app: (f x)
-- | Pretty-printer that emits the same surface syntax the parser reads.
instance Show Term where
  show term =
    case term of
      Nam k       -> k
      Dry f x     -> paren f x
      Var k       -> k
      Dp0 k       -> k ++ "₀"
      Dp1 k       -> k ++ "₁"
      Era         -> "&{}"
      Sup l a b   -> concat ["&", l, "{", show a, ",", show b, "}"]
      Dup k l v t -> concat ["!", k, "&", l, "=", show v, ";", show t]
      Lam k f     -> binder k f
      Abs k f     -> binder k f
      App f x     -> paren f x
    where
      -- Applications, stuck or not, share one rendering.
      paren f x  = concat ["(", show f, " ", show x, ")"]
      -- Lam and Abs share one rendering.
      binder k f = concat ["λ", k, ".", show f]
-- Environment
-- ===========

-- | Evaluator state, passed explicitly through every reduction step.
data Env = Env
  { inters  :: Int              -- interactions performed so far
  , var_new :: Int              -- counter used by fresh_var
  , dup_new :: Int              -- counter used by fresh_dup
  , var_map :: Map Term         -- pending substitutions for x
  , dp0_map :: Map Term         -- pending substitutions for x₀
  , dp1_map :: Map Term         -- pending substitutions for x₁
  , dup_map :: Map (Lab,Term)   -- delayed duplications: binder -> (label, value)
  } deriving Show
-- | Initial state: all counters at zero and every table empty.
env :: Env
env = Env
  { inters  = 0
  , var_new = 0
  , dup_new = 0
  , var_map = M.empty
  , dp0_map = M.empty
  , dp1_map = M.empty
  , dup_map = M.empty
  }
-- | Spell a positive counter in bijective base-@length chars@ using @chars@
-- as digits (1 -> first char, base -> last char, base+1 -> two chars, ...).
-- Counters <= 0 yield the empty string.
nameFrom :: String -> Int -> String
nameFrom chars = go []
  where
    radix = length chars
    -- Peel off the least significant digit and prepend it to the accumulator.
    go acc k
      | k > 0     = go (chars !! ((k - 1) `mod` radix) : acc) ((k - 1) `div` radix)
      | otherwise = acc
-- | Bump the counter selected by @get@/@set@ and return the updated state
-- together with a "$"-prefixed name spelled from the new counter value.
fresh :: (Env -> Int) -> (Env -> Int -> Env) -> String -> Env -> (Env, String)
fresh get set chars s = (set s n, '$' : nameFrom chars n)
  where
    n = get s + 1

-- | Fresh variable name, drawn from the lowercase alphabet.
fresh_var :: Env -> (Env, String)
fresh_var = fresh var_new (\e n -> e { var_new = n }) ['a'..'z']

-- | Fresh duplication name, drawn from the uppercase alphabet.
fresh_dup :: Env -> (Env, String)
fresh_dup = fresh dup_new (\e n -> e { dup_new = n }) ['A'..'Z']
-- | Insert @k -> v@ into the table selected by @get@/@set@.
subst :: (Env -> Map a) -> (Env -> Map a -> Env) -> Env -> String -> a -> Env
subst get set s k v = set s table
  where
    table = M.insert k v (get s)

-- | Record a substitution for variable @x@.
subst_var :: Env -> String -> Term -> Env
subst_var = subst var_map (\e m -> e { var_map = m })

-- | Record a substitution for @x₀@.
subst_dp0 :: Env -> String -> Term -> Env
subst_dp0 = subst dp0_map (\e m -> e { dp0_map = m })

-- | Record a substitution for @x₁@.
subst_dp1 :: Env -> String -> Term -> Env
subst_dp1 = subst dp1_map (\e m -> e { dp1_map = m })

-- | Park a duplication (label, value) until one of its halves is demanded.
delay_dup :: Env -> String -> (Lab,Term) -> Env
delay_dup = subst dup_map (\e m -> e { dup_map = m })
-- | Remove key @k@ from the table selected by @get@/@set@, returning the value
-- it held (if any) together with the state whose table no longer has the key.
taker :: (Env -> Map a) -> (Env -> Map a -> Env) -> Env -> String -> (Maybe a, Env)
taker get set s k = (M.lookup k table, set s (M.delete k table))
  where
    table = get s

-- | Consume the substitution stored for variable @x@.
take_var :: Env -> String -> (Maybe Term, Env)
take_var = taker var_map (\e m -> e { var_map = m })

-- | Consume the substitution stored for @x₀@.
take_dp0 :: Env -> String -> (Maybe Term, Env)
take_dp0 = taker dp0_map (\e m -> e { dp0_map = m })

-- | Consume the substitution stored for @x₁@.
take_dp1 :: Env -> String -> (Maybe Term, Env)
take_dp1 = taker dp1_map (\e m -> e { dp1_map = m })

-- | Consume the delayed duplication stored for binder @x@.
take_dup :: Env -> String -> (Maybe (Lab,Term), Env)
take_dup = taker dup_map (\e m -> e { dup_map = m })
-- | Count one more interaction.
inc_inters :: Env -> Env
inc_inters s = s { inters = count + 1 }
  where
    count = inters s
-- Parsing
-- =======

-- | Run a parser after discarding any leading whitespace.
lexeme :: ReadP a -> ReadP a
lexeme p = skipSpaces >> p
-- | An identifier, optionally preceded by whitespace.
name :: ReadP String
name = skipSpaces *> parse_nam
-- | Any term: the symmetric choice of all term forms, after optional
-- leading whitespace.
parse_term :: ReadP Term
parse_term = lexeme $
      parse_lam
  +++ parse_dup
  +++ parse_app
  +++ parse_sup
  +++ parse_era
  +++ parse_var
-- | Parenthesised application @(f x y ...)@, folded to the left into 'App'
-- nodes. A single parenthesised term parses to that term itself.
parse_app :: ReadP Term
parse_app = between open close (many1 parse_term) >>= build
  where
    open  = lexeme (char '(')
    close = lexeme (char ')')
    build (t:rest) = return (Prelude.foldl App t rest)
    build _        = pfail
-- | Lambda @λx.body@; a backslash is accepted in place of @λ@.
parse_lam :: ReadP Term
parse_lam =
  Lam <$> (lexeme (char 'λ' +++ char '\\') *> name)
      <*> (lexeme (char '.') *> parse_term)
-- | Duplication @!x &L = value; body@.
parse_dup :: ReadP Term
parse_dup =
  Dup <$> (sym '!' *> name)
      <*> (sym '&' *> name)
      <*> (sym '=' *> parse_term)
      <*> (sym ';' *> parse_term)
  where
    sym = lexeme . char
-- | Superposition @&L{a,b}@.
parse_sup :: ReadP Term
parse_sup = do
  l <- sym '&' *> name
  a <- sym '{' *> parse_term
  b <- sym ',' *> parse_term <* sym '}'
  return (Sup l a b)
  where
    sym = lexeme . char
-- | Eraser literal @&{}@.
parse_era :: ReadP Term
parse_era = Era <$ lexeme (string "&{}")
-- | A name, read as @x₀@ ('Dp0'), @x₁@ ('Dp1') or plain @x@ ('Var')
-- depending on the subscript that immediately follows it.
parse_var :: ReadP Term
parse_var = do
  k <- name
  (Dp0 k <$ string "₀") +++ (Dp1 k <$ string "₁") +++ return (Var k)
-- | One or more identifier characters: ASCII letters, digits, @_@ or @/@.
parse_nam :: ReadP String
parse_nam = munch1 isNameChar
  where
    isNameChar c =
         within 'a' 'z'
      || within 'A' 'Z'
      || within '0' '9'
      || c == '_'
      || c == '/'
      where
        within lo hi = c >= lo && c <= hi
-- | Parse a whole string into a term. Calls 'error' with "bad-parse" unless
-- the input has exactly one complete parse.
read_term :: String -> Term
read_term s = pick (readP_to_S whole s)
  where
    whole = parse_term <* skipSpaces <* eof
    pick [(t, "")] = t
    pick _         = error "bad-parse"
-- Evaluation
-- ==========

-- | Reduce a term to weak normal form. Applications reduce their head and
-- then interact; duplications are parked in the state and evaluation goes on
-- with the body; variables follow pending substitutions. Any other term is
-- returned unchanged.
wnf :: Env -> Term -> (Env,Term)
wnf s term =
  case term of
    App f x     -> let (s0, f0) = wnf s f in app s0 f0 x
    Dup k l v t -> wnf (delay_dup s k (l, v)) t
    Var x       -> var s x
    Dp0 x       -> dp0 s x
    Dp1 x       -> dp1 s x
    _           -> (s, term)
-- | Apply a weak-normal head to an argument, dispatching on the head's shape.
-- Heads with no application rule are left as a plain 'App' node.
app :: Env -> Term -> Term -> (Env,Term)
app s fun x =
  case fun of
    Nam fk       -> app_nam s fk x
    Dry df dx    -> app_dry s df dx x
    Lam fk ff    -> app_lam s fk ff x
    Sup fl fa fb -> app_sup s fl fa fb x
    _            -> (s, App fun x)
-- | Duplicate a weak-normal value @val@ under binder @k@ and label @l@, then
-- continue with @t@. Values with no duplication rule are rebuilt as a 'Dup'.
dup :: Env -> String -> Lab -> Term -> Term -> (Env,Term)
dup s k l val t =
  case val of
    Nam vk       -> dup_nam s k l vk t
    Dry vf vx    -> dup_dry s k l vf vx t
    Lam vk vf    -> dup_lam s k l vk vf t
    Sup vl va vb -> dup_sup s k l vl va vb t
    _            -> (s, Dup k l val t)
-- Interactions
-- ============

-- (λx.f v)
-- ---------- app-lam
-- x ← v
-- f
--
-- Counts one interaction, binds the argument to the lambda's variable and
-- keeps reducing the body.
app_lam s fx ff v = wnf bound ff
  where
    bound = subst_var (inc_inters s) fx v
-- (&fL{fa,fb} v)
-- -------------------- app-sup
-- ! x &fL = v
-- &fL{(fa x₀),(fb x₁)}
--
-- Counts one interaction, duplicates the argument under a fresh binder and
-- applies each branch of the superposition to one half.
app_sup s fL fa fb v = wnf s1 (Dup x fL v branches)
  where
    (s1, x)  = fresh_dup (inc_inters s)
    branches = Sup fL (App fa (Dp0 x)) (App fb (Dp1 x))
-- (fk v)
-- ------ app-nam
-- (fk v)
--
-- Applying a stuck name yields a stuck application; it still counts as one
-- interaction.
app_nam s fk v = (inc_inters s, stuck)
  where
    stuck = Dry (Nam fk) v
-- ((df dx) v)
-- ----------- app-dry
-- ((df dx) v)
--
-- Applying a stuck application extends it by one argument; it still counts
-- as one interaction.
app_dry s df dx v = (inc_inters s, stuck)
  where
    stuck = Dry (Dry df dx) v
-- ! k &L = λvk.vf; t
-- ------------------ dup-lam
-- k₀ ← λx0.g₀
-- k₁ ← λx1.g₁
-- vk ← &L{x0,x1}
-- ! g &L = vf
-- t
--
-- Fresh names are drawn in the order x0, x1, g.
dup_lam s k l vk vf t = wnf s6 (Dup g l vf t)
  where
    (s1, x0) = fresh_var (inc_inters s)
    (s2, x1) = fresh_var s1
    (s3, g)  = fresh_dup s2
    s4 = subst_dp0 s3 k (Lam x0 (Dp0 g))
    s5 = subst_dp1 s4 k (Lam x1 (Dp1 g))
    s6 = subst_var s5 vk (Sup l (Var x0) (Var x1))
-- ! k &L = &vL{va,vb}; t
-- ---------------------- dup-sup
-- if L == vL:
--   k₀ ← va
--   k₁ ← vb
--   t
-- else:
--   k₀ ← &vL{a₀,b₀}
--   k₁ ← &vL{a₁,b₁}
--   ! a &L = va
--   ! b &L = vb
--   t
--
-- Both cases count one interaction; only the commuting case (different
-- labels) draws fresh names, in the order a, b.
dup_sup s k l vl va vb t
  | l == vl   = wnf annihilated t
  | otherwise = wnf commuted (Dup a l va (Dup b l vb t))
  where
    s0 = inc_inters s
    -- Same label: each half takes one branch directly.
    annihilated = subst_dp1 (subst_dp0 s0 k va) k vb
    -- Different labels: each half becomes a superposition of fresh halves.
    (s1, a)  = fresh_dup s0
    (s2, b)  = fresh_dup s1
    s3       = subst_dp0 s2 k (Sup vl (Dp0 a) (Dp0 b))
    commuted = subst_dp1 s3 k (Sup vl (Dp1 a) (Dp1 b))
-- ! k &L = vk; t
-- -------------- dup-nam
-- k₀ ← vk
-- k₁ ← vk
-- t
--
-- A stuck name is copied as-is into both halves.
dup_nam s k l vk t = wnf both t
  where
    copy = Nam vk
    both = subst_dp1 (subst_dp0 (inc_inters s) k copy) k copy
-- ! k &L = (vf vx); t
-- --------------------- dup-dry
-- ! f &L = vf
-- ! x &L = vx
-- k₀ ← (f₀ x₀)
-- k₁ ← (f₁ x₁)
-- t
--
-- Fresh names are drawn in the order f, x.
dup_dry s k l vf vx t = wnf s4 (Dup f l vf (Dup x l vx t))
  where
    (s1, f) = fresh_dup (inc_inters s)
    (s2, x) = fresh_dup s1
    s3 = subst_dp0 s2 k (Dry (Dp0 f) (Dp0 x))
    s4 = subst_dp1 s3 k (Dry (Dp1 f) (Dp1 x))
-- x
-- ------------ var
-- var_map[x]
--
-- Consumes the substitution for x if there is one and keeps reducing it;
-- otherwise the variable is returned untouched with the original state.
var :: Env -> String -> (Env,Term)
var s k = maybe (s, Var k) (wnf s') hit
  where
    (hit, s') = take_var s k
-- x₀
-- ---------- dp0
-- dp0_map[x]
--
-- Tries, in order: a ready substitution for x₀; a delayed duplication for x,
-- whose value is reduced and then duplicated before x₀ is looked up again;
-- otherwise x₀ is returned untouched.
dp0 :: Env -> String -> (Env,Term)
dp0 s k
  | (Just t, s0) <- take_dp0 s k = wnf s0 t
  | (Just (l, v), s0) <- take_dup s k =
      let (s1, v0) = wnf s0 v in dup s1 k l v0 (Dp0 k)
  | otherwise = (s, Dp0 k)
-- x₁
-- ---------- dp1
-- dp1_map[x]
--
-- Tries, in order: a ready substitution for x₁; a delayed duplication for x,
-- whose value is reduced and then duplicated before x₁ is looked up again;
-- otherwise x₁ is returned untouched.
dp1 :: Env -> String -> (Env,Term)
dp1 s k
  | (Just t, s0) <- take_dp1 s k = wnf s0 t
  | (Just (l, v), s0) <- take_dup s k =
      let (s1, v0) = wnf s0 v in dup s1 k l v0 (Dp1 k)
  | otherwise = (s, Dp1 k)
-- Normalization
-- =============

-- | Full normal form: reduce to weak normal form, then normalise the
-- sub-terms left to right. Before descending under a 'Lam', its variable is
-- substituted by a stuck name of the same spelling.
nf :: Env -> Term -> (Env,Term)
nf s x = uncurry go (wnf s x)
  where
    go st t =
      case t of
        Dry f a     -> both st Dry f a
        Sup l a b   -> both st (Sup l) a b
        Dup k l v b -> both st (Dup k l) v b
        App f a     -> both st App f a
        Lam k b     -> Lam k <$> nf (subst_var st k (Nam k)) b
        Abs k b     -> Abs k <$> nf st b
        _           -> (st, t)   -- Nam, Var, Dp0, Dp1, Era: nothing below
    -- Normalise two children in sequence, threading the state through.
    both st mk a b =
      let (s0, a0) = nf st a
          (s1, b0) = nf s0 b
      in (s1, mk a0 b0)
-- Main
-- ====

-- | Source text of the benchmark term: @n@ chained duplications, each
-- wrapping the previous one in @λx.(F₀ (F₁ x))@, followed by one last such
-- lambda over the final binder.
f :: Int -> String
f n = "λf. " ++ concatMap dupLine [0 .. n - 1] ++ step (n - 1)
  where
    dupLine 0 = "!F00 &A = f;\n "
    dupLine i = "!F" ++ pad i ++ " &A = " ++ step (i - 1) ++ ";\n "
    -- λxJ.(FJ₀ (FJ₁ xJ)) for the two-digit index J
    step j = concat ["λx", p, ".(F", p, "₀ (F", p, "₁ x", p, "))"]
      where
        p = pad j
    pad x
      | x < 10    = "0" ++ show x
      | otherwise = show x
-- | The benchmark term: @f 18@ applied to a function on Church booleans and
-- then to the Church-encoded @true@.
term :: Term
term = read_term (concat ["((", f 18, " λX.((X λT0.λF0.F0) λT1.λF1.T1)) λT2.λF2.T2)"])
-- | Normalise the benchmark term, then print the result followed by the
-- number of interactions performed.
main :: IO ()
main = do
  let (final, result) = nf env term
  print result
  print (inters final)
PROBLEM:
while the code above works, it is too slow. currently, it takes about 4 seconds
to return the correct result, performing about 450k interactions per second. for
perspective, similar implementations in C perform about 100m interactions per
second - a 200x difference.
GOAL:
refactor the file above to optimize it as much as possible.
IMPORTANT:
keep a parser that reads the term as a string (don't build it manually)
check that the result is still preserved. currently, it prints:
λ$bgqec.λ$bgqee.$bgqec
1835080
it is fine to print an equivalent λ-term (ex: λt.λf.t or λa.λb.a, etc.), but
it is NOT fine to print a different λ-term. the interaction count (1835080)
must also be the same, confirming we're running "the same computation".
HINTS:
- you could replace immutable subst maps by a massive mmap'ed array
- you could remove strings / names from runtime terms and re-add when printing
- you could use strict annotations and other pragmas to optimize memory usage
- most of the work is done by the recursive wnf traversal; make sure it runs
fast - perhaps by using tail calls, if that is possible here.
- using global states and making everything IO is faster than pair-passing
ultimately, it is your job to make it as fast as possible, while keeping the
same computation. you're free to ignore my suggestions and do anything you want!
now, write below a complete, refactored Haskell file that does the equivalent
computation, but runs much faster than the one above.
@aaravq
Copy link

aaravq commented Nov 7, 2025

Curr Checkpoint

To achieve the desired performance improvement, we refactor the solution to use mutable arrays (IOVector) and integer identifiers instead of strings and Map. We maintain the exact semantics and interaction count.

{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -O2 -funbox-strict-fields #-}

import Data.IORef
import Data.Maybe (fromJust)
import System.IO.Unsafe (unsafePerformIO)
import Text.ParserCombinators.ReadP
import qualified Data.Map.Strict as M
import qualified Data.Vector.Mutable as VM
import qualified Data.Vector.Unboxed.Mutable as UM
import Control.Monad (when, forM_)

-- Core Types using Int IDs

-- | Integer identifier that replaces textual names at runtime. The evaluator
-- uses it directly as an index into the vectors held by 'GlobalEnv'.
type NameId = Int

-- | Runtime term. All fields are strict; identifier fields are unpacked.
data Term
  = Nam {-# UNPACK #-} !NameId       -- stuck name
  | Var {-# UNPACK #-} !NameId       -- variable x
  | Dp0 {-# UNPACK #-} !NameId       -- first half of a duplication
  | Dp1 {-# UNPACK #-} !NameId       -- second half of a duplication
  | Era                              -- eraser &{}
  | Sup {-# UNPACK #-} !NameId !Term !Term                        -- label, two branches
  | Dup {-# UNPACK #-} !NameId {-# UNPACK #-} !NameId !Term !Term -- binder, label, value, body
  | Lam {-# UNPACK #-} !NameId !Term -- binder, body
  | Abs {-# UNPACK #-} !NameId !Term -- printed the same way as Lam by showTerm
  | Dry !Term !Term                  -- stuck application
  | App !Term !Term                  -- application

-- | 'show' defers to 'showTerm', which reads the global name tables; the
-- 'unsafePerformIO' only wraps those reads.
instance Show Term where
  show = unsafePerformIO . showTerm

-- Global Mutable State Environment

-- | All mutable evaluator state: counters, per-identifier substitution slots
-- with their presence flags, and the metadata needed to print identifiers.
-- Every vector is indexed by identifier.
data GlobalEnv = GlobalEnv
  { inters_ref :: {-# UNPACK #-} !(IORef Int)  -- interactions performed so far
  , var_new_ref :: {-# UNPACK #-} !(IORef Int) -- number of fresh variables created
  , dup_new_ref :: {-# UNPACK #-} !(IORef Int) -- number of fresh dup binders created
  , next_id_ref :: {-# UNPACK #-} !(IORef Int) -- next unused identifier
  -- Substitution slots, one per identifier
  , var_vec :: {-# UNPACK #-} !(VM.IOVector Term)       -- value bound to x
  , dp0_vec :: {-# UNPACK #-} !(VM.IOVector Term)       -- value bound to x₀
  , dp1_vec :: {-# UNPACK #-} !(VM.IOVector Term)       -- value bound to x₁
  , dup_vec :: {-# UNPACK #-} !(VM.IOVector (Int,Term)) -- delayed dup: (label, value)
  -- Presence flags (0=empty, 1=full); a slot is only read when its flag is 1
  , var_flg :: {-# UNPACK #-} !(UM.IOVector Int)
  , dp0_flg :: {-# UNPACK #-} !(UM.IOVector Int)
  , dp1_flg :: {-# UNPACK #-} !(UM.IOVector Int)
  , dup_flg :: {-# UNPACK #-} !(UM.IOVector Int)
  -- ID metadata for name reconstruction
  , id_kind  :: {-# UNPACK #-} !(UM.IOVector Int) -- 0=input, 1=var, 2=dup
  , id_idx   :: {-# UNPACK #-} !(UM.IOVector Int) -- sequence number
  , id_names :: {-# UNPACK #-} !(IORef (M.Map Int String)) -- input names
  }

-- | Number of slots allocated for each vector in the global environment.
-- Identifiers index those vectors without bounds checks, so the total number
-- of identifiers ever created must stay below this value.
cap :: Int
cap = 16 * 1024 * 1024 -- 16M slots

-- | Allocate a fresh environment: counters at zero, presence flags cleared,
-- value and metadata vectors left uninitialised, and no input names recorded.
initEnv :: IO GlobalEnv
initEnv =
  GlobalEnv
    <$> newIORef 0          -- inters_ref
    <*> newIORef 0          -- var_new_ref
    <*> newIORef 0          -- dup_new_ref
    <*> newIORef 0          -- next_id_ref
    <*> VM.unsafeNew cap    -- var_vec
    <*> VM.unsafeNew cap    -- dp0_vec
    <*> VM.unsafeNew cap    -- dp1_vec
    <*> VM.unsafeNew cap    -- dup_vec
    <*> UM.replicate cap 0  -- var_flg
    <*> UM.replicate cap 0  -- dp0_flg
    <*> UM.replicate cap 0  -- dp1_flg
    <*> UM.replicate cap 0  -- dup_flg
    <*> UM.unsafeNew cap    -- id_kind
    <*> UM.unsafeNew cap    -- id_idx
    <*> newIORef M.empty    -- id_names

-- | The single process-wide environment. NOINLINE keeps GHC from duplicating
-- the 'unsafePerformIO' call, so every use site shares the same allocation.
{-# NOINLINE env #-}
env :: GlobalEnv
env = unsafePerformIO initEnv

-- Name Management

-- | Return the identifier of a name that appears in the parsed input,
-- allocating a new one (kind 0) the first time the name is seen.
--
-- The table maps identifier -> name, so the reverse lookup is a direct scan
-- of its entries rather than building an inverted map on every call. A name
-- is inserted only when absent, so at most one entry can match.
registerInputName :: String -> IO Int
registerInputName s = do
  known <- readIORef (id_names env)
  case [i | (i, n) <- M.toList known, n == s] of
    (i:_) -> return i
    [] -> do
      i <- readIORef (next_id_ref env)
      writeIORef (next_id_ref env) (i + 1)
      UM.unsafeWrite (id_kind env) i 0
      modifyIORef' (id_names env) (M.insert i s)
      return i

-- | Allocate an identifier for a fresh variable (kind 1) and record its
-- 1-based sequence number, which is what the printer spells out.
-- The new identifier is not checked against 'cap'.
freshVarId :: IO Int
freshVarId = do
  i <- bump (next_id_ref env)
  n <- bump (var_new_ref env)
  UM.unsafeWrite (id_kind env) i 1
  UM.unsafeWrite (id_idx env) i (n + 1)
  return i
  where
    -- Post-increment: return the old value and store old + 1.
    bump ref = do
      old <- readIORef ref
      writeIORef ref (old + 1)
      return old

-- | Allocate an identifier for a fresh duplication binder (kind 2) and record
-- its 1-based sequence number, which is what the printer spells out.
-- The new identifier is not checked against 'cap'.
freshDupId :: IO Int
freshDupId = do
  i <- bump (next_id_ref env)
  n <- bump (dup_new_ref env)
  UM.unsafeWrite (id_kind env) i 2
  UM.unsafeWrite (id_idx env) i (n + 1)
  return i
  where
    -- Post-increment: return the old value and store old + 1.
    bump ref = do
      old <- readIORef ref
      writeIORef ref (old + 1)
      return old

-- Parsing with mutable setup

-- | Any term. Each parser yields an IO action that registers the names it
-- saw and builds the integer-indexed term when run.
parse_term_io :: ReadP (IO Term)
parse_term_io = lexeme $
      parse_lam_io
  +++ parse_dup_io
  +++ parse_app_io
  +++ parse_sup_io
  +++ parse_era_io
  +++ parse_var_io

-- | Lambda @λx.body@ (a backslash is accepted in place of @λ@). The binder
-- is registered before the body is built.
parse_lam_io :: ReadP (IO Term)
parse_lam_io = do
  _ <- lexeme (char 'λ' +++ char '\\')
  k <- parse_nam
  _ <- lexeme (char '.')
  body <- parse_term_io
  return (Lam <$> registerInputName k <*> body)

-- | Duplication @!x &L = value; body@. Binder and label are registered, in
-- that order, before the value and the body are built.
parse_dup_io :: ReadP (IO Term)
parse_dup_io = do
  k <- sym '!' *> parse_nam
  l <- sym '&' *> parse_nam
  v <- sym '=' *> parse_term_io
  t <- sym ';' *> parse_term_io
  return (Dup <$> registerInputName k <*> registerInputName l <*> v <*> t)
  where
    sym = lexeme . char

-- | Parenthesised application @(f x y ...)@, folded to the left into 'App'
-- nodes once every component has been built.
parse_app_io :: ReadP (IO Term)
parse_app_io = do
  ts <- between (lexeme (char '(')) (lexeme (char ')')) (many1 parse_term_io)
  return (sequence ts >>= assemble)
  where
    assemble (h:rest) = return (Prelude.foldl App h rest)
    assemble _        = error "parse_app empty"

-- | Superposition @&L{a,b}@. The label is registered before the branches
-- are built.
parse_sup_io :: ReadP (IO Term)
parse_sup_io = do
  l <- sym '&' *> parse_nam
  a <- sym '{' *> parse_term_io
  b <- sym ',' *> parse_term_io <* sym '}'
  return (Sup <$> registerInputName l <*> a <*> b)
  where
    sym = lexeme . char

-- | Eraser literal @&{}@; nothing needs registering.
parse_era_io :: ReadP (IO Term)
parse_era_io = return Era <$ lexeme (string "&{}")

-- | A name, read as @x₀@ ('Dp0'), @x₁@ ('Dp1') or plain @x@ ('Var')
-- depending on the subscript that immediately follows it.
--
-- The subscript literals must be non-empty: @string ""@ always succeeds, so
-- with empty literals every variable would parse three ways and 'read_term'
-- would reject the input as ambiguous.
parse_var_io :: ReadP (IO Term)
parse_var_io = do
  k <- parse_nam
  mk <- choice
    [ Dp0 <$ string "₀"
    , Dp1 <$ string "₁"
    , return Var
    ]
  return (mk <$> registerInputName k)

-- | One or more identifier characters (ASCII letters, digits, @_@ or @/@),
-- optionally preceded by whitespace.
parse_nam :: ReadP String
parse_nam = lexeme (munch1 isNameChar)
  where
    isNameChar c =
         within 'a' 'z'
      || within 'A' 'Z'
      || within '0' '9'
      || c == '_'
      || c == '/'
      where
        within lo hi = c >= lo && c <= hi

-- | Run a parser after discarding any leading whitespace.
lexeme :: ReadP a -> ReadP a
lexeme p = skipSpaces >> p

-- | Parse a whole string and run the resulting builder action. Calls 'error'
-- with "bad-parse" unless the input has exactly one complete parse.
-- The 'unsafePerformIO' runs the name registrations against the global
-- environment when the result is first demanded.
read_term :: String -> Term
read_term s = pick (readP_to_S whole s)
  where
    whole = parse_term_io <* skipSpaces <* eof
    pick [(act, "")] = unsafePerformIO act
    pick _           = error "bad-parse"

-- Evaluation

-- | Count one more interaction; the new total is forced before it is stored.
inc_inters :: IO ()
inc_inters = do
  n <- readIORef counter
  writeIORef counter $! n + 1
  where
    counter = inters_ref env

-- | Reduce a term to weak normal form. Applications reduce their head and
-- then interact; duplications are parked in the binder's slot and evaluation
-- goes on with the body; variables follow pending substitutions. Any other
-- term is returned unchanged.
wnf :: Term -> IO Term
wnf term =
  case term of
    App f x     -> wnf f >>= \f0 -> app f0 x
    Dup k l v t -> do
      UM.unsafeWrite (dup_flg env) k 1
      VM.unsafeWrite (dup_vec env) k (l, v)
      wnf t
    Var k       -> var k
    Dp0 k       -> dp0 k
    Dp1 k       -> dp1 k
    _           -> return term

-- | Apply a weak-normal head to an argument.
--
--   * stuck name / stuck application: extend the stuck application;
--   * lambda: bind the argument to the lambda's variable and reduce the body;
--   * superposition: duplicate the argument under a fresh binder and apply
--     each branch to one half;
--   * anything else: return a plain 'App' without counting an interaction.
app :: Term -> Term -> IO Term
app fun x =
  case fun of
    Nam fk -> stuck (Nam fk)
    Dry df dx -> stuck (Dry df dx)
    Lam fk ff -> do
      inc_inters
      UM.unsafeWrite (var_flg env) fk 1
      VM.unsafeWrite (var_vec env) fk x
      wnf ff
    Sup fl fa fb -> do
      inc_inters
      ki <- freshDupId
      wnf (Dup ki fl x (Sup fl (App fa (Dp0 ki)) (App fb (Dp1 ki))))
    _ -> return (App fun x)
  where
    stuck hd = do
      inc_inters
      return (Dry hd x)

-- | Duplicate a weak-normal value @val@ under binder @k@ and label @l@, then
-- continue with @t@.
--
--   * name: both halves receive the name;
--   * stuck application: function and argument are duplicated under fresh
--     binders (allocated in that order);
--   * lambda: each half is a lambda over a fresh variable, the original
--     variable becomes a superposition of the two, and the body is duplicated
--     under a fresh binder (fresh ids drawn as x0, x1, g);
--   * superposition, same label: each half takes one branch;
--   * superposition, other label: both branches are duplicated under fresh
--     binders (a, then b);
--   * anything else: the 'Dup' node is rebuilt, no interaction is counted.
dup :: Int -> Int -> Term -> Term -> IO Term
dup k l val t =
  case val of
    Nam vk -> do
      inc_inters
      bindHalves (Nam vk) (Nam vk)
      wnf t
    Dry vf vx -> do
      inc_inters
      fi <- freshDupId
      xi <- freshDupId
      bindHalves (Dry (Dp0 fi) (Dp0 xi)) (Dry (Dp1 fi) (Dp1 xi))
      wnf (Dup fi l vf (Dup xi l vx t))
    Lam vk vf -> do
      inc_inters
      x0 <- freshVarId
      x1 <- freshVarId
      g <- freshDupId
      bindHalves (Lam x0 (Dp0 g)) (Lam x1 (Dp1 g))
      UM.unsafeWrite (var_flg env) vk 1
      VM.unsafeWrite (var_vec env) vk (Sup l (Var x0) (Var x1))
      wnf (Dup g l vf t)
    Sup vl va vb
      | l == vl -> do
          inc_inters
          bindHalves va vb
          wnf t
      | otherwise -> do
          inc_inters
          a <- freshDupId
          b <- freshDupId
          bindHalves (Sup vl (Dp0 a) (Dp0 b)) (Sup vl (Dp1 a) (Dp1 b))
          wnf (Dup a l va (Dup b l vb t))
    _ -> return (Dup k l val t)
  where
    -- Fill the k₀ and k₁ slots and mark both as present.
    bindHalves h0 h1 = do
      UM.unsafeWrite (dp0_flg env) k 1
      VM.unsafeWrite (dp0_vec env) k h0
      UM.unsafeWrite (dp1_flg env) k 1
      VM.unsafeWrite (dp1_vec env) k h1

-- | Resolve variable @k@: if a substitution is pending, consume it (clearing
-- its flag) and keep reducing; otherwise return the variable as is.
var :: Int -> IO Term
var k = do
  flg <- UM.unsafeRead (var_flg env) k
  if flg /= 1
    then return (Var k)
    else do
      val <- VM.unsafeRead (var_vec env) k
      UM.unsafeWrite (var_flg env) k 0
      wnf val

-- | Resolve @k₀@. A ready substitution is consumed and reduced. Failing that,
-- a delayed duplication on @k@ is taken, its value reduced and duplicated,
-- and @k₀@ is looked up again. With neither, @k₀@ is returned as is.
dp0 :: Int -> IO Term
dp0 k = do
  ready <- UM.unsafeRead (dp0_flg env) k
  case ready of
    1 -> do
      val <- VM.unsafeRead (dp0_vec env) k
      UM.unsafeWrite (dp0_flg env) k 0
      wnf val
    _ -> do
      pending <- UM.unsafeRead (dup_flg env) k
      case pending of
        1 -> do
          (l, v) <- VM.unsafeRead (dup_vec env) k
          UM.unsafeWrite (dup_flg env) k 0
          v0 <- wnf v
          dup k l v0 (Dp0 k)
        _ -> return (Dp0 k)

-- | Resolve @k₁@. A ready substitution is consumed and reduced. Failing that,
-- a delayed duplication on @k@ is taken, its value reduced and duplicated,
-- and @k₁@ is looked up again. With neither, @k₁@ is returned as is.
dp1 :: Int -> IO Term
dp1 k = do
  ready <- UM.unsafeRead (dp1_flg env) k
  case ready of
    1 -> do
      val <- VM.unsafeRead (dp1_vec env) k
      UM.unsafeWrite (dp1_flg env) k 0
      wnf val
    _ -> do
      pending <- UM.unsafeRead (dup_flg env) k
      case pending of
        1 -> do
          (l, v) <- VM.unsafeRead (dup_vec env) k
          UM.unsafeWrite (dup_flg env) k 0
          v0 <- wnf v
          dup k l v0 (Dp1 k)
        _ -> return (Dp1 k)

-- | Full normal form: reduce to weak normal form, then normalise the
-- sub-terms left to right. Before descending under a 'Lam', its variable is
-- bound to a stuck name carrying the same identifier.
nf :: Term -> IO Term
nf t = wnf t >>= go
  where
    go (Dry f x)     = Dry <$> nf f <*> nf x
    go (Sup l a b)   = Sup l <$> nf a <*> nf b
    go (Dup k l v b) = Dup k l <$> nf v <*> nf b
    go (App f x)     = App <$> nf f <*> nf x
    go (Abs k b)     = Abs k <$> nf b
    go (Lam k b)     = do
      UM.unsafeWrite (var_flg env) k 1
      VM.unsafeWrite (var_vec env) k (Nam k)
      Lam k <$> nf b
    go leaf          = return leaf   -- Nam, Var, Dp0, Dp1, Era

-- Display helpers

-- | Spell a positive counter in bijective base-@length chars@ using @chars@
-- as digits (1 -> first char, base -> last char, base+1 -> two chars, ...).
-- Counters <= 0 yield the empty string.
nameFrom :: String -> Int -> String
nameFrom chars = go []
  where
    radix = length chars
    -- Peel off the least significant digit and prepend it to the accumulator.
    go acc k
      | k > 0     = go (chars !! ((k - 1) `mod` radix) : acc) ((k - 1) `div` radix)
      | otherwise = acc

-- | Recover the printable name of an identifier from its recorded kind:
-- input names (kind 0) come from the name table; fresh variables (kind 1)
-- and fresh dup binders (kind 2) are spelled as "$" plus their sequence
-- number in the lowercase / uppercase alphabet.
--
-- Calls 'error' for an unknown kind, or for an input identifier that is
-- missing from the name table; the message names the offending identifier.
idToString :: Int -> IO String
idToString i = do
  kind <- UM.unsafeRead (id_kind env) i
  case kind of
    0 -> do
      names <- readIORef (id_names env)
      case M.lookup i names of
        Just nm -> return nm
        Nothing -> error ("idToString: no input name recorded for id " ++ show i)
    1 -> generated ['a'..'z']
    2 -> generated ['A'..'Z']
    _ -> error "unknown id kind"
  where
    generated alphabet = do
      idx <- UM.unsafeRead (id_idx env) i
      return ("$" ++ nameFrom alphabet idx)

-- | Render a term in the surface syntax accepted by the parser, resolving
-- identifiers back to names through 'idToString'.
--
-- 'Dp0' and 'Dp1' carry the "₀" / "₁" subscripts so they can be told apart
-- from a plain 'Var' with the same name and read back by the parser.
showTerm :: Term -> IO String
showTerm term =
  case term of
    Nam k       -> idToString k
    Var k       -> idToString k
    Dp0 k       -> (++ "₀") <$> idToString k
    Dp1 k       -> (++ "₁") <$> idToString k
    Era         -> return "&{}"
    Dry f x     -> paren f x
    App f x     -> paren f x
    Lam k f     -> binder k f
    Abs k f     -> binder k f
    Sup l a b   -> do
      ls <- idToString l
      as <- showTerm a
      bs <- showTerm b
      return (concat ["&", ls, "{", as, ",", bs, "}"])
    Dup k l v t -> do
      ks <- idToString k
      ls <- idToString l
      vs <- showTerm v
      ts <- showTerm t
      return (concat ["!", ks, "&", ls, "=", vs, ";", ts])
  where
    -- Applications, stuck or not, share one rendering.
    paren f x = do
      fs <- showTerm f
      xs <- showTerm x
      return (concat ["(", fs, " ", xs, ")"])
    -- Lam and Abs share one rendering.
    binder k f = do
      ks <- idToString k
      fs <- showTerm f
      return (concat ["λ", ks, ".", fs])

-- Main logic

-- | Source text of the benchmark term: @n@ chained duplications, each
-- wrapping the previous one in @λx.(F₀ (F₁ x))@, followed by one last such
-- lambda over the final binder.
f :: Int -> String
f n = "λf. " ++ concatMap dupLine [0 .. n - 1] ++ step (n - 1)
  where
    dupLine 0 = "!F00 &A = f;\n    "
    dupLine i = "!F" ++ pad i ++ " &A = " ++ step (i - 1) ++ ";\n    "
    -- λxJ.(FJ₀ (FJ₁ xJ)) for the two-digit index J
    step j = concat ["λx", p, ".(F", p, "₀ (F", p, "₁ x", p, "))"]
      where
        p = pad j
    pad x
      | x < 10    = "0" ++ show x
      | otherwise = show x

-- | Source text of the benchmark: @f 18@ applied to a function on Church
-- booleans and then to the Church-encoded @true@.
inputStr :: String
inputStr = concat ["((", f 18, " λX.((X λT0.λF0.F0) λT1.λF1.T1)) λT2.λF2.T2)"]

-- | Parse and normalise the benchmark term, then print the result followed
-- by the number of interactions performed.
main :: IO ()
main = do
  res <- nf (read_term inputStr)
  showTerm res >>= putStrLn
  readIORef (inters_ref env) >>= print

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment