Finite sets in Haskell

This assumes familiarity with kinds and GADTs. Here are the extensions you'll need:

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ExplicitForAll #-}

The type-level natural Nat from GHC.TypeLits is convenient, but it isn't defined inductively, which makes it awkward for non-trivial type-level computations. Here it's easier to use Peano numbers, a recursive way of defining the natural numbers.

data Nat = Z | S Nat
  deriving Show

Z stands for the number zero. S n is the successor of n. For example:

  • S Z is the number 1, the successor of 0.
  • S (S Z) is the number 2, the successor of 1.
  • S (S (S Z)) is the number 3, the successor of 2.
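
To make the correspondence with ordinary numbers concrete, here is a minimal sketch that counts the S constructors of a Peano number. The name natToInt is made up for this illustration and doesn't come from any library.

```haskell
-- Peano naturals, as defined above.
data Nat = Z | S Nat
  deriving Show

-- Hypothetical helper (illustration only): Z is 0, and each S adds one.
natToInt :: Nat -> Int
natToInt Z     = 0
natToInt (S n) = 1 + natToInt n
```

For instance, natToInt (S (S (S Z))) peels off three S constructors before reaching Z.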

Thanks to DataKinds, these constructors are promoted to types (and Nat itself to a kind), so we can define some type synonyms to make them easier to work with:

type N0 = Z
type N1 = S N0
type N2 = S N1
type N3 = S N2
type N4 = S N3

With these in place, we can define a collection type that's indexed by its length:

data Vector (length :: Nat) a = DummyConstructor
  deriving Show

(Note: here I'm just using a dummy constructor. To learn more about length-indexed vectors, check out Matt Parsons's Basic Type Level Programming in Haskell.)

Now we can create vectors and know at compile-time how many elements they hold. E.g., Vector N3 Int is a vector that holds 3 integers.
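
For illustration only, here is a rough sketch of what a vector that actually stores its elements could look like. The names Vec, VNil, VCons, vecToList and example are assumptions chosen for this sketch; this is one common encoding, not necessarily the one used in the post mentioned above.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

data Nat = Z | S Nat

-- Assumed encoding: every VCons bumps the length index by one.
data Vec (n :: Nat) a where
  VNil  :: Vec Z a
  VCons :: a -> Vec n a -> Vec (S n) a

-- Forget the length index and return an ordinary list.
vecToList :: Vec n a -> [a]
vecToList VNil         = []
vecToList (VCons x xs) = x : vecToList xs

-- The type records that this vector holds exactly 3 Ints;
-- adding or removing a VCons would be a type error.
example :: Vec (S (S (S Z))) Int
example = VCons 10 (VCons 20 (VCons 30 VNil))
```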

When retrieving a specific element from this collection, we want to say "give me the nth element". But for this to be safe, n must be between 0 and the vector's length - 1.

To solve this problem, we can use a finite set, Fin n, that represents all numbers in the range [0, n).

  • Fin 0 is an uninhabited type
  • Fin 1 is inhabited by {0}
  • Fin 2 is inhabited by {0,1}
  • Fin 3 is inhabited by {0,1,2}, etc.

data Fin (n :: Nat) where
  FZ :: forall n. Fin (S n)
  FS :: forall n. Fin n -> Fin (S n)

deriving instance Show (Fin n)

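The constructor FZ is polymorphic: it inhabits every finite set Fin x, as long as x is the successor of another number n. The constructor FS takes an element of Fin n and returns the next element, which lives in the larger set Fin (S n).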

It's hard to explain how Fin works, so let's just see it in action.

λ> zero = FZ

λ> zero :: Fin N0   -- doesn't compile
λ> zero :: Fin N1   -- compiles
λ> zero :: Fin N2   -- compiles
λ> zero :: Fin N3   -- compiles

λ> one = FS FZ

λ> one :: Fin N0   -- doesn't compile
λ> one :: Fin N1   -- doesn't compile
λ> one :: Fin N2   -- compiles
λ> one :: Fin N3   -- compiles

λ> two = FS (FS FZ)

λ> two :: Fin N0   -- doesn't compile
λ> two :: Fin N1   -- doesn't compile
λ> two :: Fin N2   -- doesn't compile
λ> two :: Fin N3   -- compiles
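
Another way to read these values is to convert them back to plain numbers. The sketch below assumes a helper called finToInt and a hand-written list allFin3; both names are invented for this example.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

data Nat = Z | S Nat

data Fin (n :: Nat) where
  FZ :: Fin (S n)
  FS :: Fin n -> Fin (S n)

-- Hypothetical helper: FZ stands for 0, and every FS adds one.
finToInt :: Fin n -> Int
finToInt FZ     = 0
finToInt (FS i) = 1 + finToInt i

-- Three inhabitants of Fin 3, written out by hand.
-- Adding FS (FS (FS FZ)) to this list would not type-check.
allFin3 :: [Fin (S (S (S Z)))]
allFin3 = [FZ, FS FZ, FS (FS FZ)]
```

Mapping finToInt over allFin3 recovers the numbers in the range [0, 3).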

Now, we can write a function that is guaranteed to safely fetch an element from a vector.

getElem :: Fin n -> Vector n a -> a
getElem = undefined

This function's type says: if a vector has n elements, we have to give it an index within the range [0, n).

-- for convenience
zero = FZ
one = FS zero
two = FS one
three = FS two
four = FS three

-- a vector with 2 ints
v = DummyConstructor :: Vector N2 Int

x = getElem one v    -- we can get the elem at index one
y = getElem three v  -- (doesn't compile) but we can't get the element at index three, because a vector of length 2 doesn't have one
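
Since the dummy Vector holds no data, getElem can't really be implemented for it. As a sketch under the assumption that the vector is a cons-list GADT (the names Vec, VNil, VCons, index and pair are chosen here for illustration), the lookup could be written like this:

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}

data Nat = Z | S Nat

data Fin (n :: Nat) where
  FZ :: Fin (S n)
  FS :: Fin n -> Fin (S n)

-- Assumed vector encoding, standing in for the dummy Vector above.
data Vec (n :: Nat) a where
  VNil  :: Vec Z a
  VCons :: a -> Vec n a -> Vec (S n) a

-- No VNil case is needed: both FZ and FS produce a Fin (S n),
-- so no index has type Fin Z and the vector can't be empty here.
index :: Fin n -> Vec n a -> a
index FZ     (VCons x _)  = x
index (FS i) (VCons _ xs) = index i xs

-- A vector with 2 ints.
pair :: Vec (S (S Z)) Int
pair = VCons 7 (VCons 9 VNil)
```

Here index FZ pair and index (FS FZ) pair are both accepted, while index (FS (FS FZ)) pair is rejected at compile time, just like getElem three v above.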

Richard Eisenberg talks more about Fin here.

