Created
August 17, 2020 19:58
-
-
Save mniip/aa4534f13361f64aa80ec620657958f1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE BangPatterns, RoleAnnotations, LambdaCase, ViewPatterns #-} | |
import Control.Monad | |
import Control.Monad.ST | |
import Control.Monad.ST.Unsafe | |
import Data.STRef | |
import Debug.Trace | |
import GHC.IO | |
-- | A mutable cell: a thin wrapper around an 'STRef' that lives in state
-- thread @s@ and holds a value of type @a@.
newtype Cell s a = Cell (STRef s a)

-- Both parameters are pinned to the nominal role, so a 'Cell' cannot be
-- coerced into a cell over a different state thread or content type.
type role Cell nominal nominal
-- | Run an 'ST' action as if it were a pure value.
--
-- The action is first reinterpreted as 'IO' with 'unsafeSTToIO' and then run
-- with 'unsafeDupablePerformIO'. Nothing here ties the action to a particular
-- state thread or orders it relative to other effects; that burden is on the
-- caller.
unsafeDupablePerformST :: ST s a -> a
unsafeDupablePerformST action = unsafeDupablePerformIO (unsafeSTToIO action)
-- | Allocate a fresh 'Cell' initialised to @x@.
--
-- Returns the cell together with two "pure" handles onto the same reference:
--
-- * a reader, which takes a dummy @()@ and yields the current contents;
-- * a writer, which stores its argument and yields @()@ (the store only
--   happens when that @()@ is forced).
--
-- Both handles go through 'unsafeDupablePerformST'.
--
-- NOTE(review): the reader ignores its argument, so an optimising compiler
-- may be able to share a single read between calls; 'union' already re-checks
-- roots with real 'readCell' reads for that reason. Confirm the optimisation
-- flags this is built with.
newCell :: a -> ST s (Cell s a, () -> a, a -> ())
newCell x = do
  ref <- newSTRef x
  let peek _ = unsafeDupablePerformST (readSTRef ref)
      poke v = unsafeDupablePerformST (writeSTRef ref v)
  -- we don't care if the accesses/updates get duplicated
  pure (Cell ref, peek, poke)
-- | Read the current contents of a 'Cell' as a properly sequenced 'ST' action.
readCell :: Cell s a -> ST s a
readCell cell = case cell of
  Cell ref -> readSTRef ref
-- | Overwrite the contents of a 'Cell' as a properly sequenced 'ST' action.
writeCell :: Cell s a -> a -> ST s ()
writeCell (Cell ref) newValue = writeSTRef ref newValue
-- | One node of the disjoint-set forest.
--
-- The parent link is reachable three ways: through the 'Cell' itself (for
-- sequenced 'ST' access) and through the two function fields, which are
-- filled in by 'makeSet' from the handles returned by 'newCell'.
data UF s a = UF
  { ufValue        :: !a
    -- ^ Payload carried by this node (strict).
  , ufParentCell   :: Cell s (Maybe (UF s a))
    -- ^ Parent link; 'Nothing' marks a root.
  , ufParentAccess :: () -> Maybe (UF s a)
    -- ^ Reads the parent link from pure code.
  , ufParentUpdate :: Maybe (UF s a) -> ()
    -- ^ Overwrites the parent link when its @()@ result is forced.
  , ufSizeRef      :: STRef s Int
    -- ^ Size counter consulted and updated by 'union'.
  }
-- | Find the root of @uf@'s tree with full path compression.
--
-- Every node visited on the way up is re-pointed directly at the root. The
-- previous version only re-pointed the node it was called on and left the
-- rest of the path untouched, so repeated finds on deeper nodes never got
-- cheaper.
--
-- The recursion depth equals the length of the path from @uf@ to the root.
ufFindRootCompression :: UF s a -> UF s a
ufFindRootCompression uf = case ufParentAccess uf () of
  Nothing -> uf
  Just parent ->
    -- Resolve (and compress) the rest of the path first; the bang makes sure
    -- the root is known before this node's link is rewritten.
    let !root = ufFindRootCompression parent
    in ufParentUpdate uf (Just root) `seq` root
-- | Find the root of @uf@'s tree using path halving.
--
-- While the current node has a grandparent, the node is re-pointed at that
-- grandparent and the search continues from the grandparent, so every other
-- node on the path gets a shorter link.
ufFindRootHalving :: UF s a -> UF s a
ufFindRootHalving node = maybe node viaParent (ufParentAccess node ())
  where
    -- The node has a parent: either the parent is the root, or we can skip it.
    viaParent parent = maybe parent skipTo (ufParentAccess parent ())
    -- Link the node straight to its grandparent, then carry on from there.
    skipTo grandparent =
      ufParentUpdate node (Just grandparent) `seq` ufFindRootHalving grandparent
-- | Find the root of @uf@'s tree using path splitting.
--
-- Walking upward one step at a time, each node that has a grandparent is
-- re-pointed at it; unlike halving, the walk then advances by a single node.
ufFindRootSplitting :: UF s a -> UF s a
ufFindRootSplitting start = maybe start (walk start) (ufParentAccess start ())
  where
    -- @walk child parent@: @parent@ is the current parent of @child@.
    walk child parent = case ufParentAccess parent () of
      Nothing -> parent
      Just grandparent ->
        ufParentUpdate child (Just grandparent) `seq` walk parent grandparent
-- | Debug helper: follow parent links without modifying any of them and
-- return the root together with the number of links traversed.
ufDebugRoot :: UF s a -> (UF s a, Int)
ufDebugRoot = climb 0
  where
    climb !depth node = case ufParentAccess node () of
      Nothing -> (node, depth)
      Just parent -> climb (succ depth) parent
-- | The root-finding strategy used by 'find' and 'union'.
-- Currently path halving; the commented-out lines select the alternatives.
ufFindRoot :: UF s a -> UF s a
ufFindRoot uf = ufFindRootHalving uf
--ufFindRoot = ufFindRootCompression
--ufFindRoot = ufFindRootSplitting
-- | Create a singleton set whose only element (and root) carries @x@.
--
-- The size counter starts at 1. It previously started at 0, which meant the
-- @sz1 + sz2@ written by 'union' was always 0, the @sz1 < sz2@ comparison
-- never succeeded, and union-by-size silently degraded to "always attach the
-- second root under the first".
makeSet :: a -> ST s (UF s a)
makeSet x = do
  sz <- newSTRef 1
  -- No parent yet: a fresh node is its own root.
  (p, pa, pu) <- newCell Nothing
  pure $ UF
    { ufValue = x
    , ufParentCell = p
    , ufParentAccess = pa
    , ufParentUpdate = pu
    , ufSizeRef = sz
    }
-- | Value stored at the root of the tree containing the given node, i.e. the
-- representative of its set. Goes through 'ufFindRoot', so it may shorten
-- parent links as a side effect.
find :: UF s a -> a
find uf = ufValue (ufFindRoot uf)
-- | Merge the sets containing @uf1@ and @uf2@.
--
-- The root with the smaller recorded size is attached beneath the other; on a
-- tie the root of @uf2@ goes beneath the root of @uf1@. The surviving root's
-- size counter is set to the sum of both.
--
-- If both nodes already share a root nothing is changed. Without that check
-- the root would be made its own parent, creating a cycle on which every
-- later root search loops forever.
union :: UF s a -> UF s a -> ST s ()
union uf1 uf2 = do
  root1 <- seqFindRoot (ufFindRoot uf1)
  -- maybe the ufFindRoot computation got memoized and contains outdated data
  root2 <- seqFindRoot (ufFindRoot uf2)
  -- Each node owns its own size reference and 'STRef' equality is reference
  -- identity, so equal references mean the two roots are the same node.
  unless (ufSizeRef root1 == ufSizeRef root2) $ do
    sz1 <- readSTRef $ ufSizeRef root1
    sz2 <- readSTRef $ ufSizeRef root2
    if sz1 < sz2
      then attach root1 root2 $ sz1 + sz2
      else attach root2 root1 $ sz1 + sz2
  where
    -- Confirm a candidate root with a real, sequenced read of its parent
    -- cell; if it has a parent after all, keep searching from there.
    seqFindRoot uf = readCell (ufParentCell uf) >>= \case
      Nothing -> pure uf
      Just uf' -> seqFindRoot $ ufFindRoot uf'
    -- Make r2 the parent of r1 and record the combined size on r2.
    attach r1 r2 sz = do
      writeSTRef (ufSizeRef r2) sz
      writeCell (ufParentCell r1) $ Just r2
-- | Tiny command interpreter over 256 singleton sets numbered 0..255.
--
-- Reads commands from standard input, one per line, and reports through
-- 'traceM':
--
-- * @find i@    prints the depth of node @i@ and its set's representative;
-- * @union i j@ merges the two sets and prints the root now reached from @i@.
--
-- Blank lines are skipped. Anything else (unknown command, non-numeric or
-- out-of-range index) is reported and skipped instead of aborting the
-- program with a pattern-match, 'read' or '!!' failure.
main :: IO ()
main = do
  cs <- map words . lines <$> getContents
  evaluate $ runST $ do
    ufs <- mapM makeSet [0..255]
    -- Total replacement for @ufs !! read str@: succeeds only for a complete
    -- integer parse that is a valid index into 'ufs'.
    let node str = case reads str of
          [(i, "")] | i >= 0, (u:_) <- drop i ufs -> Just u
          _ -> Nothing
    forM_ cs $ \case
      ["find", node -> Just u] ->
        traceM $ "(depth " ++ show (snd $ ufDebugRoot u) ++ ") " ++ show (find u)
      ["union", node -> Just u, node -> Just v] -> do
        union u v
        traceM $ "(new root " ++ show (ufValue . fst $ ufDebugRoot u) ++ ")"
      [] -> pure ()
      ws -> traceM $ "(unrecognized command: " ++ unwords ws ++ ")"
{- | |
$ ./UF | |
union 0 1 | |
(new root 0) | |
union 2 3 | |
(new root 2) | |
union 4 5 | |
(new root 4) | |
union 6 7 | |
(new root 6) | |
union 8 9 | |
(new root 8) | |
union 10 11 | |
(new root 10) | |
union 12 13 | |
(new root 12) | |
union 14 15 | |
(new root 14) | |
union 1 3 | |
(new root 0) | |
union 5 7 | |
(new root 4) | |
union 9 11 | |
(new root 8) | |
union 13 15 | |
(new root 12) | |
union 0 4 | |
(new root 0) | |
union 8 12 | |
(new root 8) | |
union 1 9 | |
(new root 0) | |
find 15 | |
(depth 4) 0 | |
find 15 | |
(depth 2) 0 | |
find 15 | |
(depth 1) 0 | |
find 7 | |
(depth 3) 0 | |
find 7 | |
(depth 2) 0 | |
find 7 | |
(depth 1) 0 | |
find 14 | |
(depth 2) 0 | |
find 14 | |
(depth 1) 0 | |
-} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment