Skip to content

Instantly share code, notes, and snippets.

@mniip
Created August 17, 2020 19:58
Show Gist options
  • Save mniip/aa4534f13361f64aa80ec620657958f1 to your computer and use it in GitHub Desktop.
Save mniip/aa4534f13361f64aa80ec620657958f1 to your computer and use it in GitHub Desktop.
{-# LANGUAGE BangPatterns, RoleAnnotations, LambdaCase, ViewPatterns #-}
import Control.Monad
import Control.Monad.ST
import Control.Monad.ST.Unsafe
import Data.STRef
import Debug.Trace
import GHC.IO
-- | A mutable cell belonging to state thread @s@; a thin wrapper around
-- 'STRef' so that strict, properly sequenced reads and writes can be done in
-- 'ST' (see 'readCell' / 'writeCell').
newtype Cell s a = Cell (STRef s a)
-- Both parameters are declared nominal, so 'coerce' cannot turn a
-- @Cell s a@ into a @Cell s b@ even when @a@ and @b@ share a representation.
-- NOTE(review): the motivation for making @a@ nominal is not stated here.
type role Cell nominal nominal
-- | Run an 'ST' computation from pure code by reinterpreting it as 'IO' and
-- handing it to 'unsafeDupablePerformIO'. As with that function, the effect
-- happens only if and when the result is demanded, and it may be performed
-- more than once, so it must only be used where that is harmless.
unsafeDupablePerformST :: ST s a -> a
unsafeDupablePerformST action = unsafeDupablePerformIO (unsafeSTToIO action)
-- | Allocate a cell holding @x@ and return it together with two /pure/
-- handles onto the same underlying 'STRef':
--
--   * a reader of type @() -> a@ that returns the cell's current contents, and
--   * a writer of type @a -> ()@ that stores a value when its '()' result is
--     forced (e.g. with 'seq').
--
-- Both handles escape 'ST' through 'unsafeDupablePerformST'.
--
-- NOTE(review): the "System.IO.Unsafe" documentation warns that GHC's full
-- laziness (on at @-O@) may float an unsafe call that does not depend on the
-- lambda's argument out of that lambda. That would turn the reader into a
-- memoized constant. Consider compiling with @-fno-full-laziness@ (and
-- @-fno-cse@); confirm against the actual build flags.
newCell :: a -> ST s (Cell s a, () -> a, a -> ())
newCell x = do
  r <- newSTRef x
  pure (Cell r, \_ -> unsafeDupablePerformST $ readSTRef r, unsafeDupablePerformST . writeSTRef r)
  -- The Dupable variant is acceptable because we don't care if the
  -- accesses/updates get duplicated.
-- | Read a cell's current contents inside 'ST'. Unlike the pure reader
-- returned by 'newCell', this read is sequenced with the surrounding 'ST'
-- actions.
readCell :: Cell s a -> ST s a
readCell cell = case cell of
  Cell ref -> readSTRef ref
-- | Overwrite a cell's contents inside 'ST', sequenced with the surrounding
-- 'ST' actions.
writeCell :: Cell s a -> a -> ST s ()
writeCell (Cell ref) newValue = writeSTRef ref newValue
-- | A node of the union-find forest living in state thread @s@.
--
-- The three parent fields are three views of one parent pointer ('makeSet'
-- builds them from a single 'newCell' call). 'Nothing' marks a root.
data UF s a = UF
  { ufValue :: !a
    -- ^ the element stored at this node
  , ufParentCell :: Cell s (Maybe (UF s a))
    -- ^ parent pointer, for sequenced reads/writes inside 'ST'
  , ufParentAccess :: () -> Maybe (UF s a)
    -- ^ unsafe pure read of the same parent pointer
  , ufParentUpdate :: Maybe (UF s a) -> ()
    -- ^ unsafe pure write of the same parent pointer; it happens when the
    -- returned '()' is forced
  , ufSizeRef :: STRef s Int
    -- ^ size counter that 'union' compares to choose the surviving root;
    -- 'union' only writes it on roots
  }
-- | Find the root of @uf@'s tree using full path compression: afterwards
-- every node on the path from @uf@ to the root points directly at the root.
--
-- The rest of the path is compressed first (recursively). The bang pattern
-- forces that recursive search. 'seq' forces the unsafe parent write for
-- @uf@ so it is not left behind as an unevaluated '()'. The recursion depth
-- equals the length of the path.
ufFindRootCompression :: UF s a -> UF s a
ufFindRootCompression uf = case ufParentAccess uf () of
  Nothing -> uf
  Just parent
    -- Recursing through ufFindRootCompression (rather than a pure walk to
    -- the root) relinks every intermediate node, not just uf itself.
    | !root <- ufFindRootCompression parent ->
        ufParentUpdate uf (Just root) `seq` root
-- | Find the root of @uf@'s tree using path halving. Whenever the current
-- node has a grandparent, its parent pointer is redirected to that
-- grandparent and the search continues from the grandparent. This skips
-- every other node on the path.
--
-- Parent pointers are read and written through the unsafe pure accessors
-- stored in the 'UF' record. Each write is attached with 'seq' so that it is
-- performed when the result is demanded instead of being discarded as an
-- unevaluated '()'.
ufFindRootHalving :: UF s a -> UF s a
ufFindRootHalving uf = case ufParentAccess uf () of
  Nothing -> uf
  Just uf' -> case ufParentAccess uf' () of
    -- uf' has no parent: it is the root.
    Nothing -> uf'
    -- Relink uf to its grandparent, then continue from that grandparent.
    Just uf'' -> ufParentUpdate uf (Just uf'') `seq` ufFindRootHalving uf''
-- | Find the root of @uf@'s tree using path splitting. Every node on the path
-- that has a grandparent gets its parent pointer redirected to that
-- grandparent, while the search advances one node at a time.
--
-- Writes go through the unsafe pure accessor 'ufParentUpdate' and are forced
-- with 'seq' as the result is demanded.
ufFindRootSplitting :: UF s a -> UF s a
ufFindRootSplitting uf = case ufParentAccess uf () of
  Nothing -> uf
  Just uf' -> go uf uf'
  where
    -- go child parent, where parent is child's current parent:
    -- if parent has a parent of its own (the grandparent), relink child to
    -- it and advance one step up the path; otherwise parent is the root.
    go uf' uf'' = case ufParentAccess uf'' () of
      Nothing -> uf''
      Just uf''' -> ufParentUpdate uf' (Just uf''') `seq` go uf'' uf'''
-- | Debugging helper: follow parent pointers with the pure reader (never
-- writing anything). Return the root together with the number of edges
-- walked to reach it (0 when the node is itself a root).
ufDebugRoot :: UF s a -> (UF s a, Int)
ufDebugRoot = walk 0
  where
    walk !steps node = case ufParentAccess node () of
      Nothing -> (node, steps)
      Just parent -> walk (succ steps) parent
-- | The root-finding strategy used by 'find' and 'union'. Exactly one of the
-- equations below should be active; switch strategies by (un)commenting.
ufFindRoot :: UF s a -> UF s a
--ufFindRoot node = ufFindRootCompression node
ufFindRoot node = ufFindRootHalving node
--ufFindRoot node = ufFindRootSplitting node
-- | Create a new singleton set containing @x@. The new node is a root (its
-- parent is 'Nothing') and its tree has size 1.
makeSet :: a -> ST s (UF s a)
makeSet x = do
  -- A singleton tree has exactly one node. Starting at 0 would keep every
  -- size at 0 forever (0 + 0 = 0). 'union' would then always keep its first
  -- argument's root and never balance by size.
  sz <- newSTRef 1
  (p, pa, pu) <- newCell Nothing
  pure $ UF
    { ufValue = x
    , ufParentCell = p
    , ufParentAccess = pa
    , ufParentUpdate = pu
    , ufSizeRef = sz
    }
-- | Return the value stored at the root of the set containing @uf@, i.e. the
-- set's representative. The root search may rewrite parent pointers along
-- the way (see 'ufFindRoot').
find :: UF s a -> a
find uf = ufValue (ufFindRoot uf)
-- | Merge the sets containing the two given elements using union by size.
-- The root of the smaller tree is attached below the root of the larger one.
-- On a tie, the root of the first argument's set stays the root.
--
-- If both elements already share a root, nothing is changed. Without this
-- check the root would be made its own parent, and every later root search
-- through it would loop forever.
union :: UF s a -> UF s a -> ST s ()
union uf1 uf2 = do
  root1 <- seqFindRoot (ufFindRoot uf1)
  -- maybe the ufFindRoot computation got memoized and contains outdated data,
  -- so each candidate root is re-checked with a sequenced read (seqFindRoot)
  root2 <- seqFindRoot (ufFindRoot uf2)
  -- Every node gets its own size ref in makeSet, and STRef equality is
  -- reference identity, so equal refs mean "same root, same set".
  unless (ufSizeRef root1 == ufSizeRef root2) $ do
    sz1 <- readSTRef $ ufSizeRef root1
    sz2 <- readSTRef $ ufSizeRef root2
    if sz1 < sz2
      then attach root1 root2 $ sz1 + sz2
      else attach root2 root1 $ sz1 + sz2
  where
    -- Confirm that uf is a root by reading its parent cell inside ST; if it
    -- still has a parent, keep searching from there.
    seqFindRoot uf = readCell (ufParentCell uf) >>= \case
      Nothing -> pure uf
      Just uf' -> seqFindRoot $ ufFindRoot uf'
    -- Make r1 a child of r2 and record the merged size on r2.
    attach r1 r2 sz = do
      writeSTRef (ufSizeRef r2) sz
      writeCell (ufParentCell r1) $ Just r2
-- | Read commands from stdin, one per line, and run them against 256
-- singleton sets holding the values 0..255:
--
--   * @find i@     — trace the depth of element @i@ and the value of its root
--   * @union i j@  — merge the sets of @i@ and @j@ and trace the new root
--
-- Blank lines are skipped. Malformed commands, non-numeric arguments and
-- out-of-range indices are reported with 'traceM' instead of aborting the run.
main :: IO ()
main = do
  cs <- map words . lines <$> getContents
  evaluate $ runST $ do
    ufs <- mapM makeSet [0 .. setCount - 1]
    forM_ cs $ \case
      [] -> pure ()
      ["find", parseIndex -> Just i] -> do
        -- The depth is shown to the left of the find result in the traced
        -- string. The sample session in the file shows the depth before
        -- find's path rewriting takes effect.
        traceM $ "(depth " ++ show (snd . ufDebugRoot $ ufs !! i) ++ ") " ++ show (find $ ufs !! i)
      ["union", parseIndex -> Just i, parseIndex -> Just j] -> do
        union (ufs !! i) (ufs !! j)
        traceM $ "(new root " ++ show (ufValue . fst . ufDebugRoot $ ufs !! i) ++ ")"
      cmd -> traceM $ "(ignored invalid command: " ++ unwords cmd ++ ")"
  where
    -- Number of sets created; valid indices are 0 .. setCount - 1.
    setCount :: Int
    setCount = 256
    -- Parse a whole token as an in-range set index. Any other token yields
    -- Nothing, so the !! lookups above can never go out of bounds.
    parseIndex :: String -> Maybe Int
    parseIndex s = case reads s of
      [(i, "")] | i >= 0 && i < setCount -> Just i
      _ -> Nothing
{-
$ ./UF
union 0 1
(new root 0)
union 2 3
(new root 2)
union 4 5
(new root 4)
union 6 7
(new root 6)
union 8 9
(new root 8)
union 10 11
(new root 10)
union 12 13
(new root 12)
union 14 15
(new root 14)
union 1 3
(new root 0)
union 5 7
(new root 4)
union 9 11
(new root 8)
union 13 15
(new root 12)
union 0 4
(new root 0)
union 8 12
(new root 8)
union 1 9
(new root 0)
find 15
(depth 4) 0
find 15
(depth 2) 0
find 15
(depth 1) 0
find 7
(depth 3) 0
find 7
(depth 2) 0
find 7
(depth 1) 0
find 14
(depth 2) 0
find 14
(depth 1) 0
-}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment