Skip to content

Instantly share code, notes, and snippets.

@kputnam
Last active August 29, 2015 14:05
Show Gist options
  • Save kputnam/3364aed303fa9fc301ec to your computer and use it in GitHub Desktop.
Save kputnam/3364aed303fa9fc301ec to your computer and use it in GitHub Desktop.
Demo of Union-Find algorithm used for solving type constraints
module Demo where
import Control.Applicative
import Control.Monad.ST
import UnionFind
import Weight
-- | Look up the weight of @c@'s equivalence class and, when its tag is 'Con',
-- return the class label wrapped in 'Just'. A class tagged 'Var' yields
-- 'Nothing'.
con_ :: Class s (Weight a) b -> ST s (Maybe b)
con_ c = do
  w <- weight c
  case tag w of
    Con -> Just <$> label c
    Var -> return Nothing
-- | Worked example: create four one-element classes (two type variables and
-- two nullary type constructors) and unify them pairwise.
--
-- As written, the second unification is deliberately illegal, so forcing 'x'
-- fails with a @can't unify ...@ message produced by @abort@.
x :: (String, Weight String)
x = runST $ do
  -- Note we have to take care not to allow the same type to belong to
  -- multiple equivalence classes. Possibly by constructing a map from
  -- Type to @Class s (Weight Type) Type@.
  a <- weighted (var "a") "a"
  b <- weighted (var "b") "b"
  bool <- weighted (con "Bool") "Bool"
  char <- weighted (con "Char") "Char"

  -- When unifying two equivalence classes, both represented by a saturated type
  -- constructor, we (as the user) need to check that the constructors match and
  -- recursively unify the corresponding arguments to the type constructor for
  -- constructors like Maybe, [], and (->). Skipping that part for now (we only
  -- have nullary type constructors in this example)

  -- This should succeed (both aren't Con, so the comparison is Nothing)
  _ <- unify a bool

  -- This fails because a = bool from above, and now a = char implies bool = char.
  _ <- unify a char

  _ <- unify b a

  -- Assuming the illegal second unification is commented-out,
  -- => ("Bool",Weight {tag = Con, size = 3, elems = DL {unDL = fromList ["b","a","Bool"]}})
  liftA2 (,) (label a) (weight a)
  where
    -- Merge two classes unless both are represented by constructors whose
    -- labels differ.
    unify p q = do
      ok <- liftA2 (==) <$> con_ p <*> con_ q
      case ok of
        Just False -> abort p q -- two Cons with inequal constructors
        _ -> union p q -- Var and Con, or Con and Var, or two Cons with equal constructors

    -- Report a failed unification, naming both class labels.
    abort p q = do
      p' <- label p
      q' <- label q
      fail $ unwords ["can't unify", p', q']
module UnionFind
( Class
, label
, weight
, singleton
, weighted
, union
, find
, connected
) where
import Data.Maybe
import Data.STRef
import Data.Semigroup
import Control.Monad.ST
import Control.Applicative
--
-- runST $ do
-- a <- singleton 'a'
-- b <- singleton 'b'
-- c <- singleton 'c'
-- d <- singleton 'd'
--
-- union a d
-- union b a
--
-- liftA3 (,,) (label a) (label b) (label d)
--
-- | One element of the union-find forest, living in the 'ST' state thread @s@,
-- carrying a label of type @a@ and a class weight of type @m@.
data Class s m a = Class
  { label_ :: a
    -- ^ The label this element was created with.
  , parent :: STRef s (Maybe (Class s m a))
    -- ^ Link to the element one step closer to the representative;
    -- 'Nothing' when this element is itself the representative.
  , weight_ :: STRef s m
    -- ^ Weight cell; 'union' updates it on the element that becomes the root.
  }
-- | The label of the representative element of @k@'s equivalence class.
label :: Class s m a -> ST s a
label k = label_ <$> find k
-- | The weight stored on the representative element of @k@'s equivalence
-- class, i.e. the weight (comparable size) of the whole class.
weight :: Class s m a -> ST s m
weight k = do
  root <- find k
  readSTRef (weight_ root)
-- Two elements are equal exactly when their `parent` fields are the same
-- mutable reference. The comparison is on the references themselves (which
-- never change), not on what they currently point to (which does change).
-- Labels and weights are not consulted.
--
--   runST $ (\x -> x == x) <$> singleton 'a'               -- True
--   runST $ liftA2 (==) (singleton 'a') (singleton 'a')    -- False
--
instance Eq (Class s m a) where
  Class _ p _ == Class _ q _ = p == q
-- | Create a new one-element class labelled @a@ whose weight is an additive
-- element count.
--
-- The count starts at @Sum 1@: starting at 'mempty' (@Sum 0@) would leave
-- every merged weight at zero, so 'union' could never tell a large class
-- from a small one.
singleton :: a -> ST s (Class s (Sum Int) a)
singleton a = Class a <$> newSTRef Nothing <*> newSTRef (Sum 1)
-- | Create a new one-element class labelled @a@ with the caller-supplied
-- initial weight @m@. The new element has no parent.
weighted :: (Semigroup m, Ord m) => m -> a -> ST s (Class s m a)
weighted m a = do
  up <- newSTRef Nothing
  w <- newSTRef m
  return (Class a up w)
-- | Return the representative element of @k@'s equivalence class.
--
-- While walking towards the root, every visited element that has a
-- grandparent is re-pointed at that grandparent, so later lookups follow a
-- shorter chain. Only `parent` references are modified.
--
-- NOTE(review): the original comment states O(1) amortized time; confirm that
-- bound against the weight ordering actually used by callers.
find :: Class s m a -> ST s (Class s m a)
find k = readSTRef (parent k) >>= climb k
  where
    -- @climb node up@: @up@ is the current contents of @node@'s parent cell.
    climb node Nothing = return node
    climb node (Just up) = do
      above <- readSTRef (parent up)
      if isNothing above
        then return up -- @up@ has no parent: it is the root
        else do
          writeSTRef (parent node) above -- skip over @up@
          climb up above -- then repeat from @up@
-- | Join the equivalence classes of @j@ and @k@, and return the element that
-- represents the class containing both.
--
-- The root whose weight compares smaller is linked beneath the other root
-- (ties keep @j@'s root on top), and the surviving root's weight becomes
-- @weightJ <> weightK@.
--
-- The result is always a representative, including when @j@ and @k@ are the
-- same element or already share a class.
--
-- NOTE(review): the original comment states O(1) amortized time; confirm that
-- bound against the weight ordering actually used by callers.
union :: (Semigroup m, Ord m) => Class s m a -> Class s m a -> ST s (Class s m a)
union j k
  | j == k = find j -- same element: still report its representative
  | otherwise = do
      rootJ <- find j
      rootK <- find k
      -- Check if these are disjoint classes
      if rootJ /= rootK
        then choose rootJ rootK
        else return rootJ
  where
    -- Decide which root survives, based on the stored weights.
    choose rootJ rootK = do
      weightJ <- readSTRef (weight_ rootJ)
      weightK <- readSTRef (weight_ rootK)
      -- Update the smaller class to point to the larger
      case weightJ `compare` weightK of
        LT -> link rootJ rootK (weightJ <> weightK)
        _ -> link rootK rootJ (weightJ <> weightK)

    -- Hang @small@ beneath @large@ and record the merged weight on @large@.
    link small large m = do
      writeSTRef (parent small) (Just large)
      writeSTRef (weight_ large) m
      return large
-- | @True@ if @j@ and @k@ belong to the same equivalence class.
--
-- Both representatives are always looked up with 'find'. Inspecting only the
-- immediate `parent` cells is not enough: a root (whose parent is 'Nothing')
-- is connected to every element linked beneath it.
connected :: Class s m a -> Class s m a -> ST s Bool
connected j k
  | j == k = return True -- trivial case: the very same element
  | otherwise = (==) <$> find j <*> find k
{-# LANGUAGE DeriveFunctor #-}
module Weight
where
import qualified Data.DList as DL
import Data.Semigroup
-- A thin wrapper around a difference list, used to hold the collection of
-- types (variables and saturated type constructors like Int, [a], etc) in an
-- equivalence class. DList is chosen for its O(1) merging; flattening it into
-- an ordinary list costs O(n), but that is only paid when the user asks for
-- it (for example, to report diagnostic info when unification fails).
newtype DL a = DL
  { unDL :: DL.DList a
  }
  deriving (Eq, Show, Read, Functor)
-- Combining two wrapped lists appends the second after the first.
instance Semigroup (DL a) where
  front <> back = DL (unDL front `DL.append` unDL back)
-- | How an equivalence class is represented.
data Tag
  = Var -- ^ by a type variable
  | Con -- ^ by a saturated type constructor
  deriving (Eq, Show, Read)
-- | 'Var' orders below 'Con', so a fully-saturated type constructor is
-- preferred over a variable as the representative element of a class.
instance Ord Tag where
  compare l r = case (l, r) of
    (Var, Con) -> LT
    (Con, Var) -> GT
    _ -> EQ
-- | Metadata attached to each equivalence class. The UnionFind algorithm
-- doesn't care about these details, but type inference does.
data Weight a = Weight
  { tag :: Tag
    -- ^ Whether the class is represented by a 'Var' or a 'Con'.
  , size :: Int
    -- ^ Number of elements in the class.
  , elems :: DL a
    -- ^ Every element of the class.
  }
  deriving (Eq, Show, Read)
-- | Ordering consulted by the union operation, which re-parents whichever
-- class compares smaller. Tags are compared first, so a class represented by
-- a saturated type constructor outranks one represented by a variable; equal
-- tags are then ordered by element count. The element list itself is ignored.
instance Eq a => Ord (Weight a) where
  compare l r =
    case compare (tag l) (tag r) of
      EQ -> compare (size l) (size r)
      decided -> decided
-- | The weight of the class formed by unioning two classes: the greater of
-- the two tags, the sum of the two sizes, and the left element list followed
-- by the right one.
instance Semigroup (Weight a) where
  Weight ltag lsize lelems <> Weight rtag rsize relems =
    Weight winner total merged
    where
      winner = max ltag rtag
      total = lsize + rsize
      merged = lelems <> relems
-- | Weight of a one-element class whose only member is the type variable @a@.
var :: a -> Weight a
var a =
  Weight
    { tag = Var
    , size = 1
    , elems = DL (DL.singleton a)
    }
-- | Weight of a one-element class whose only member is the saturated type
-- constructor @a@.
con :: a -> Weight a
con a =
  Weight
    { tag = Con
    , size = 1
    , elems = DL (DL.singleton a)
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment