Last active
August 29, 2015 14:05
-
-
Save kputnam/3364aed303fa9fc301ec to your computer and use it in GitHub Desktop.
Demo of Union-Find algorithm used for solving type constraints
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Demo where | |
import Control.Applicative | |
import Control.Monad.ST | |
import UnionFind | |
import Weight | |
-- | Look up the label of @c@'s equivalence class, but only when the class
-- weight is tagged 'Con' (i.e. the class is represented by a constructed
-- type). Any other tag yields 'Nothing'.
con_ :: Class s (Weight a) b -> ST s (Maybe b)
con_ c = do
  w <- weight c
  if tag w == Con
    then Just <$> label c
    else return Nothing
-- | Demo of solving type-equality constraints with union-find.
--
-- Four singleton classes are created (type variables @a@ and @b@, nullary
-- constructors @Bool@ and @Char@) and then unified pairwise. Two classes that
-- are both represented by constructors may only be merged when their labels
-- agree; otherwise unification is aborted with a "can't unify" failure.
--
-- NOTE: as written, the second unification (@a ~ Char@) is deliberately
-- illegal, so evaluating 'x' ends in that failure. Comment it out to obtain
-- the result shown at the bottom of the @do@ block.
x :: (String, Weight String)
x = runST $ do
  -- Note we have to take care not to allow the same type to belong to
  -- multiple equivalence classes. Possibly by constructing a map from
  -- Type to @Class s (Weight Type) Type@.
  a    <- weighted (var "a") "a"
  b    <- weighted (var "b") "b"
  bool <- weighted (con "Bool") "Bool"
  char <- weighted (con "Char") "Char"

  -- When unifying two equivalence classes, both represented by a saturated type
  -- constructor, we (as the user) need to check that the constructors match and
  -- recursively unify the corresponding arguments to the type constructor for
  -- constructors like Maybe, [], and (->). Skipping that part for now (we only
  -- have nullary type constructors in this example)

  -- This should succeed (they aren't both Con, so the check yields Nothing)
  _ <- unify a bool

  -- This fails because a = bool from above, and now a = char implies bool = char.
  _ <- unify a char

  _ <- unify b a

  -- Assuming the illegal second unification is commented-out,
  --   => ("Bool",Weight {tag = Con, size = 3, elems = DL {unDL = fromList ["b","a","Bool"]}})
  liftA2 (,) (label a) (weight a)
  where
    -- Merge the classes of p and q unless both are represented by
    -- constructors whose labels differ.
    unify p q = do
      ok <- liftA2 (==) <$> con_ p <*> con_ q
      case ok of
        Just False -> abort p q  -- two Cons with unequal constructors
        _          -> union p q  -- Var and Con, or Con and Var, or two Cons with equal constructors

    -- Fail, naming the representatives of the two classes that clash.
    abort p q = do
      p' <- label p
      q' <- label q
      fail $ unwords ["can't unify", p', q']
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module UnionFind | |
( Class | |
, label | |
, weight | |
, singleton | |
, weighted | |
, union | |
, find | |
, connected | |
) where | |
import Data.Maybe | |
import Data.STRef | |
import Data.Semigroup | |
import Control.Monad.ST | |
import Control.Applicative | |
-- | |
-- runST $ do | |
-- a <- singleton 'a' | |
-- b <- singleton 'b' | |
-- c <- singleton 'c' | |
-- d <- singleton 'd' | |
-- | |
-- union a d | |
-- union b a | |
-- | |
-- liftA3 (,,) (label a) (label b) (label d) | |
-- | |
-- | A node of the union-find forest, living in state thread @s@, carrying a
-- weight of type @m@ and a label of type @a@.
data Class s m a = Class
  { label_  :: a
    -- ^ The value stored at this node.
  , parent  :: STRef s (Maybe (Class s m a))
    -- ^ Link towards the representative; 'find' treats @Nothing@ as
    -- "this node is the representative".
  , weight_ :: STRef s m
    -- ^ Weight cell; 'weight' reads it from the representative node.
  }
-- | The label stored at the representative of @k@'s equivalence class.
label :: Class s m a -> ST s a
label k = label_ <$> find k
-- | The weight (comparable size) recorded at the representative of @k@'s
-- equivalence class.
weight :: Class s m a -> ST s m
weight k = do
  root <- find k
  readSTRef (weight_ root)
-- Two nodes are equal exactly when their `parent` cells are the same
-- reference. The contents of the cell (which change as classes are merged)
-- are never inspected; only the identity of the reference is compared, and
-- that never changes for a given node.
--
-- Consequently a node is equal to itself, while two separate runs of the
-- same `singleton 'a'` action produce unequal nodes, since each run
-- allocates a fresh reference.
--
instance Eq (Class s m a) where
  Class _ p _ == Class _ q _ = p == q
-- | Create a new one-element class labelled @a@ whose weight is its element
-- count, tracked additively with @Sum Int@.
--
-- The count starts at @Sum 1@. Starting from 'mempty' (@Sum 0@) would leave
-- every merged class with size 0, so 'union' could never distinguish the
-- larger class from the smaller one.
singleton :: a -> ST s (Class s (Sum Int) a)
singleton a = Class a <$> newSTRef Nothing <*> newSTRef (Sum 1)
-- | Create a new one-element class labelled @a@ with the caller-supplied
-- initial weight @m@ (any type that can be combined and ordered).
weighted :: (Semigroup m, Ord m) => m -> a -> ST s (Class s m a)
weighted m a = do
  parentRef <- newSTRef Nothing
  weightRef <- newSTRef m
  return Class { label_ = a, parent = parentRef, weight_ = weightRef }
-- | Walk from @k@ up to the representative element of its equivalence class
-- and return that representative. As a side effect, nodes visited on the way
-- that have a grandparent get their `parent` link redirected to that
-- grandparent, shortening later walks. (The original author describes the
-- cost as O(1) amortized.)
find :: Class s m a -> ST s (Class s m a)
find k = readSTRef (parent k) >>= maybe (return k) (climb k)
  where
    -- @climb child p@: @p@ is the current parent of @child@.
    climb child p = do
      grand <- readSTRef (parent p)
      case grand of
        -- p has no parent of its own, so p is the representative
        Nothing -> return p
        -- Otherwise let child skip over p, then repeat one level up
        Just g  -> do
          writeSTRef (parent child) grand
          climb p g
-- | Join the equivalence classes of @j@ and @k@, and return a reference to the
-- element representing the class that now contains both.
--
-- The root with the smaller weight (per 'Ord') is linked beneath the other
-- root, and the surviving root's weight becomes @weightJ <> weightK@. When the
-- two are already in the same class nothing is modified (apart from the link
-- shortening done by 'find') and the existing representative is returned.
union :: (Semigroup m, Ord m) => Class s m a -> Class s m a -> ST s (Class s m a)
union j k
  -- Same node: still resolve to the representative, since j itself need not
  -- be the root of its class.
  | j == k    = find j
  | otherwise = do
      rootJ <- find j
      rootK <- find k

      -- Check if these are disjoint classes
      if rootJ /= rootK
        then choose rootJ rootK
        else return rootJ
  where
    choose rootJ rootK = do
      weightJ <- readSTRef (weight_ rootJ)
      weightK <- readSTRef (weight_ rootK)

      -- Update the smaller class to point to the larger; ties keep rootJ
      case weightJ `compare` weightK of
        LT -> link rootJ rootK (weightJ <> weightK)
        _  -> link rootK rootJ (weightJ <> weightK)

    -- Hang `small` beneath `large` and record the combined weight on `large`
    link small large m = do
      writeSTRef (parent small) (Just large)
      writeSTRef (weight_ large) m
      return large
-- | @True@ if @j@ and @k@ belong to the same equivalence class.
--
-- Both nodes are resolved to their representatives with 'find' and those are
-- compared. This also covers the case where one argument is itself the
-- representative of the other's class (its own `parent` is @Nothing@), which
-- must still be reported as connected.
connected :: Class s m a -> Class s m a -> ST s Bool
connected j k
  | j == k    = return True -- trivial case, no traversal needed
  | otherwise = (==) <$> find j <*> find k
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveFunctor #-} | |
module Weight | |
where | |
import qualified Data.DList as DL | |
import Data.Semigroup | |
-- A collection of types (variables and saturated type constructors such as
-- Int, [a], etc) is kept in a DList, chosen by the original author for its
-- O(1) merging. Turning it into an ordinary list costs O(n), but that price
-- is only paid when the user asks for it (for example, to report diagnostic
-- info when unification fails).
newtype DL a = DL
  { unDL :: DL.DList a
  } deriving (Eq, Show, Read, Functor)
-- Combining two wrapped lists appends the right one after the left one.
instance Semigroup (DL a) where
  (<>) (DL xs) (DL ys) = DL (xs `DL.append` ys)
-- | How an equivalence class is represented: by a type variable ('Var') or
-- by a saturated type constructor ('Con').
data Tag
  = Var
  | Con
  deriving (Eq, Show, Read)

-- | 'Var' orders below 'Con', so that fully-saturated type constructors are
-- preferred over variables as representative elements of a class.
instance Ord Tag where
  compare Var Con = LT
  compare Con Var = GT
  compare _   _   = EQ
-- | Metadata carried by each equivalence class: how the class is represented
-- (a 'Var' or a 'Con'), how many elements it contains, and the elements
-- themselves. The union-find algorithm is indifferent to these details; they
-- are recorded for the benefit of type inference.
data Weight a = Weight
  { tag   :: Tag
  , size  :: Int
  , elems :: DL a
  } deriving (Eq, Show, Read)
-- | Ordering used by the union operation to decide which of two classes is
-- the "smaller" one that gets re-pointed at the other. Tags are compared
-- first, so a class represented by a saturated type constructor outranks one
-- represented by a variable; equal tags fall back to comparing element
-- counts. The element lists themselves play no part in the ordering.
instance Eq a => Ord (Weight a) where
  compare (Weight ltag lsize _) (Weight rtag rsize _) =
    case compare ltag rtag of
      EQ    -> compare lsize rsize
      other -> other
-- | The weight of the class obtained by merging two classes: the greater of
-- the two tags, the sum of the two sizes, and the left element list followed
-- by the right one.
instance Semigroup (Weight a) where
  Weight ltag lsize ls <> Weight rtag rsize rs =
    Weight
      { tag   = max ltag rtag
      , size  = lsize + rsize
      , elems = ls <> rs
      }
-- | Weight for a brand-new class holding just the type variable @a@:
-- tagged 'Var', size 1, one-element list.
var :: a -> Weight a
var a = Weight { tag = Var, size = 1, elems = DL (DL.singleton a) }
-- | Weight for a brand-new class holding just the saturated type constructor
-- @a@: tagged 'Con', size 1, one-element list.
con :: a -> Weight a
con a = Weight { tag = Con, size = 1, elems = DL (DL.singleton a) }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment