Created
June 9, 2017 17:54
-
-
Save edsko/b33428a434ede28cd08eef9dd08fe035 to your computer and use it in GitHub Desktop.
Run-time type information in Haskell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE EmptyCase #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE MultiParamTypeClasses #-} | |
{-# LANGUAGE PolyKinds #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
{-# LANGUAGE PatternSynonyms #-} | |
{-# OPTIONS_GHC -Wall #-} | |
module RTTI ( | |
main | |
, decodeVals | |
-- * RTTI infrastructure | |
, RTTI(..) | |
, HasRTTI(..) | |
, Implicit(..) | |
, Explicit(..) | |
, implicit | |
, explicit | |
, reflectRTTI | |
) where | |
import GHC.Exts | |
import Data.Binary | |
import Data.ByteString.Lazy (ByteString) | |
import Unsafe.Coerce (unsafeCoerce) | |
{------------------------------------------------------------------------------- | |
We need some RTTI in order to serialize | |
-------------------------------------------------------------------------------} | |
-- | Run-time type information for the index of @f@.
--
-- Each type constructor @f@ that needs RTTI supplies its own @data instance@,
-- a GADT whose constructors refine the index, so that pattern matching on a
-- value of type @RTTI f a@ recovers which @a@ we are dealing with.
data family RTTI (f :: k -> *) :: k -> *
-- | Indices @a@ of @f@ whose run-time type information is statically known,
-- so that it can be passed around implicitly as a class constraint.
--
-- NOTE: 'explicit' (and therefore 'reflectRTTI') relies on this class having
-- exactly one method and no superclasses; revisit those before changing that.
class HasRTTI f a where
  -- | The run-time type information for @a@.
  rtti :: RTTI f a
-- | A computation that needs the RTTI for @a@ passed implicitly, as a
-- 'HasRTTI' constraint.
newtype Implicit f a b = Implicit { unImplicit :: HasRTTI f a => b }

-- | A computation that needs the RTTI for @a@ passed explicitly, as an
-- ordinary function argument.
newtype Explicit f a b = Explicit { unExplicit :: RTTI f a -> b }

-- | Turn an explicit computation into an implicit one: the RTTI is taken from
-- the 'HasRTTI' constraint that the consumer of the 'Implicit' provides.
--
-- Plain application (rather than @Implicit $ ...@) is used so that the
-- rank-n argument does not depend on the special typing rule for '($)'.
implicit :: Explicit f a b -> Implicit f a b
implicit (Explicit f) = Implicit (f rtti)

-- | Turn an implicit computation into an explicit one.
--
-- This cannot be written in safe Haskell, because it amounts to building a
-- 'HasRTTI' dictionary at run time from an arbitrary 'RTTI' value. Instead it
-- relies on how GHC represents class dictionaries:
--
-- * 'HasRTTI' has exactly one method ('rtti') and no superclasses. GHC
--   therefore represents its dictionary as the method itself (a newtype
--   dictionary), so a value of type @HasRTTI f a => b@ has the same run-time
--   representation as a function @RTTI f a -> b@.
--
-- * 'Implicit' and 'Explicit' are newtypes and add no representation of
--   their own.
--
-- This is the same technique used by the @reflection@ package. It stays
-- sound only while 'HasRTTI' keeps that shape.
--
-- It also assumes that the supplied 'RTTI' value agrees with any 'HasRTTI'
-- instance for the same index. For the RTTI GADTs defined in this module,
-- each index has only one possible RTTI constructor.
--
-- On GHC >= 9.4, 'GHC.Exts.withDict' is the supported replacement for this
-- coercion.
explicit :: Implicit f a b -> Explicit f a b
explicit = unsafeCoerce

-- | Bring an 'RTTI' value into scope as a 'HasRTTI' constraint for the
-- duration of the continuation.
reflectRTTI :: RTTI f a -> (HasRTTI f a => b) -> b
reflectRTTI r k = unExplicit (explicit (Implicit k)) r
{------------------------------------------------------------------------------- | |
Simple example | |
-------------------------------------------------------------------------------} | |
-- | A simple indexed value: the type index records which payload the
-- constructor carries.
data Val a where
  -- | An 'Int' payload.
  VI :: Int -> Val Int
  -- | A 'Double' payload.
  VD :: Double -> Val Double

deriving instance Show (Val a)
{------------------------------------------------------------------------------- | |
Failed attempt at serialization | |
-------------------------------------------------------------------------------} | |
{- | |
instance Binary (Val a) where | |
put (VI i) = putWord8 0 >> put i | |
put (VD d) = putWord8 1 >> put d | |
get = do | |
tag <- getWord8 | |
case tag of | |
0 -> VI <$> get -- Couldn't match type ‘a’ with ‘Int’ | |
1 -> VD <$> get -- Couldn't match type ‘a’ with ‘Double’ | |
_ -> error "invalid tag" | |
-} | |
{- | |
instance Binary (Val Int) where | |
put (VI i) = put i | |
get = VI <$> get | |
instance Binary (Val Double) where | |
put (VD d) = put d | |
get = VD <$> get | |
encodeVal :: Binary (Val a) => Val a -> ByteString | |
encodeVal = encode | |
-} | |
{------------------------------------------------------------------------------- | |
Taking advantage of RTTI | |
-------------------------------------------------------------------------------} | |
-- | RTTI for 'Val': one constructor per supported payload type, so matching
-- on it tells us whether we have a @Val Int@ or a @Val Double@.
data instance RTTI Val a where
  RttiValInt    :: RTTI Val Int
  RttiValDouble :: RTTI Val Double

instance HasRTTI Val Int where
  rtti = RttiValInt

instance HasRTTI Val Double where
  rtti = RttiValDouble
-- | Serialize a 'Val'. No constructor tag is written: the type index already
-- determines the constructor, and 'getVal' recovers it from the RTTI.
putVal :: Val a -> Put
putVal v = case v of
  VI n -> put n
  VD x -> put x
-- | Deserialize a 'Val', using the RTTI to decide which payload to read.
getVal :: RTTI Val a -> Get (Val a)
getVal r = case r of
  RttiValInt    -> fmap VI get
  RttiValDouble -> fmap VD get
-- | Any 'Val' whose index has statically known RTTI can be (de)serialized.
instance HasRTTI Val a => Binary (Val a) where
  get = getVal rtti
  put = putVal
-- | Round-trip one value of each 'Val' constructor and print the results.
exampleRoundtripVal :: IO ()
exampleRoundtripVal =
  print (roundtrip (VI 12)) >> print (roundtrip (VD 34.56))
-- | Encode a value and immediately decode it again. Like 'decode', this
-- throws if decoding fails.
roundtrip :: Binary a => a -> a
roundtrip x = decode (encode x)
{------------------------------------------------------------------------------- | |
Another example: lists | |
-------------------------------------------------------------------------------} | |
-- | An n-ary product: a heterogeneous list with one element of type @f x@
-- for each @x@ in the type-level list @xs@.
data NP (f :: k -> *) (xs :: [k]) where
  -- | The empty product.
  Nil  :: NP f '[]
  -- | Prepend an element.
  (:*) :: f x -> NP f xs -> NP f (x ': xs)

infixr 5 :*

deriving instance All Show f xs => Show (NP f xs)
-- | @All p f xs@ holds when @p (f x)@ holds for every @x@ in @xs@.
type family All p f xs :: Constraint where
  All p f '[]       = ()
  All p f (x ': xs) = (p (f x), All p f xs)
{------------------------------------------------------------------------------- | |
RTTI for lists | |
-------------------------------------------------------------------------------} | |
-- | RTTI for 'NP' records the spine of the type-level list. The cons case
-- carries the RTTI of the head element and of the tail as class dictionaries.
data instance RTTI (NP f) xs where
  RttiNpNil  :: RTTI (NP f) '[]
  RttiNpCons :: (HasRTTI f x, HasRTTI (NP f) xs) => RTTI (NP f) (x ': xs)

instance HasRTTI (NP f) '[] where
  rtti = RttiNpNil

instance (HasRTTI f x, HasRTTI (NP f) xs) => HasRTTI (NP f) (x ': xs) where
  rtti = RttiNpCons
{------------------------------------------------------------------------------- | |
Encoding lists | |
-------------------------------------------------------------------------------} | |
-- | Serialize every element of the product in order, using each element's
-- own 'Binary' instance. No length or tags are written.
putNP :: All Binary f xs => NP f xs -> Put
putNP np = case np of
  Nil     -> pure ()
  y :* ys -> put y *> putNP ys
-- | Deserialize a product whose shape is given by the RTTI, reading each
-- element with its own 'Binary' instance.
getNP :: All Binary f xs => RTTI (NP f) xs -> Get (NP f xs)
getNP r = case r of
  RttiNpNil  -> pure Nil
  RttiNpCons -> do
    hd <- get
    tl <- getNP rtti
    pure (hd :* tl)
-- | Generic (de)serialization for products. It needs 'Binary' for every
-- element and RTTI for the list shape. It is overlappable so that more
-- specific instances (such as the one for @NP Val@) take precedence.
instance {-# OVERLAPPABLE #-} (All Binary f xs, HasRTTI (NP f) xs)
      => Binary (NP f xs) where
  get = getNP rtti
  put = putNP
-- | Round-trip two heterogeneous products and print the results.
--
-- The product of 'Val's is handled by the @NP Val@ instance; the product of
-- 'Fn's uses the generic instance.
exampleBinaryLists :: IO ()
exampleBinaryLists = do
  let vals = VI 12 :* VD 34.56 :* Nil
      fns  = Mod 2 :* Sqrt :* Nil
  print (roundtrip vals)
  print (roundtrip fns)
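{-
With only the generic instance above, decodeVals would also need an
'All Binary Val xs' constraint, even though each Val's Binary instance
already follows from its RTTI:

decodeVals :: (HasRTTI (NP Val) xs, All Binary Val xs)
           => ByteString -> NP Val xs
decodeVals = decode
-}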
{-------------------------------------------------------------------------------
  But we can give another instance for Val specifically.

  The HasRTTI requirement on the elements is not strictly necessary, but it
  allows us to give a better definition for NP Val. Note that this is a 'safe'
  use of overlapping instances: the two instances are entirely compatible, and
  the more specific one just generates fewer type constraints.
-------------------------------------------------------------------------------}
-- | Serialize a product of 'Val's in order. Unlike 'putNP', this needs no
-- 'Binary' constraints on the elements, because 'putVal' handles every 'Val'.
putNpVal :: NP Val xs -> Put
putNpVal np = case np of
  Nil     -> pure ()
  v :* vs -> putVal v *> putNpVal vs
-- | Deserialize a product of 'Val's whose shape is given by the RTTI. The
-- element RTTI carried by 'RttiNpCons' supplies the 'Binary' instance for
-- each head.
getNpVal :: RTTI (NP Val) xs -> Get (NP Val xs)
getNpVal r = case r of
  RttiNpNil  -> pure Nil
  RttiNpCons -> do
    v  <- get
    vs <- getNpVal rtti
    pure (v :* vs)
-- | More specific instance for products of 'Val's. Only RTTI for the list is
-- required, because each element's RTTI is enough to (de)serialize it.
instance {-# OVERLAPPING #-} HasRTTI (NP Val) xs => Binary (NP Val xs) where
  get = getNpVal rtti
  put = putNpVal
-- | Decode a product of 'Val's; the expected shape comes from the result type.
-- Like 'decode', this throws on malformed input.
decodeVals :: HasRTTI (NP Val) xs => ByteString -> NP Val xs
decodeVals bs = decode bs
{------------------------------------------------------------------------------- | |
More involved example | |
-------------------------------------------------------------------------------} | |
-- | Typed functions from @a@ to @b@, indexed by the pair @'(a, b)@.
--
-- 'Comp' carries the RTTI of both halves as class dictionaries. The
-- intermediate type @b@ does not appear in the result index, and 'putAct'
-- uses these dictionaries to serialize the two halves.
data Fn :: (*, *) -> * where
  Exp   :: Fn '(Double, Double)
  Sqrt  :: Fn '(Double, Double)
  Mod   :: Int -> Fn '(Int, Int)
  Round :: Fn '(Double, Int)
  Comp  :: (HasRTTI Fn '(b, c), HasRTTI Fn '(a, b))
        => Fn '(b, c) -> Fn '(a, b) -> Fn '(a, c)

deriving instance Show (Fn a)
-- | Interpret a 'Fn' as an ordinary Haskell function.
--
-- Like 'mod', @eval (Mod 0)@ throws a divide-by-zero exception when applied.
eval :: Fn '(a,b) -> a -> b
eval fn = case fn of
  Exp      -> exp
  Sqrt     -> sqrt
  Mod m    -> \n -> n `mod` m
  Round    -> round
  Comp g f -> \x -> eval g (eval f x)
-- | Evaluate a few sample functions and print the results.
--
-- 'Comp' used infix has the default fixity (infixl 9); the explicit
-- parentheses below spell that grouping out.
exampleEval :: IO ()
exampleEval = do
  print (eval Exp 1)
  print (eval Sqrt 9)
  print (eval (Mod 3) 11)
  print (eval ((Round `Comp` Exp) `Comp` Exp) 1)
-- | RTTI for 'Fn': one constructor for each input/output type pair supported
-- here (Double to Double, Int to Int, Double to Int).
data instance RTTI Fn ab where
  RttiFnDD :: RTTI Fn '(Double, Double)
  RttiFnII :: RTTI Fn '(Int, Int)
  RttiFnDI :: RTTI Fn '(Double, Int)

instance HasRTTI Fn '(Double, Double) where
  rtti = RttiFnDD

instance HasRTTI Fn '(Int, Int) where
  rtti = RttiFnII

instance HasRTTI Fn '(Double, Int) where
  rtti = RttiFnDI
{------------------------------------------------------------------------------- | |
Serialization of Fn | |
-------------------------------------------------------------------------------} | |
-- | The RTTI of both halves of a composition (outer @'(b, c)@ after inner
-- @'(a, b)@), with the intermediate type @b@ hidden.
data RttiComp :: (*, *) -> * where
  RttiComp :: RTTI Fn '(b, c)    -- ^ RTTI of the outer function
           -> RTTI Fn '(a, b)    -- ^ RTTI of the inner function
           -> RttiComp '(a, c)
-- | Serialize just enough to recover the intermediate type of a composition,
-- given the RTTI of the composite. 'getRttiComp' is the inverse.
--
-- For @'(Double, Double)@ and @'(Int, Int)@ the intermediate type is forced,
-- so nothing is written. For @'(Double, Int)@ it is either 'Int' (tag 0) or
-- 'Double' (tag 1). An @'(Int, Int)@ composite with an @'(Double, Int)@ outer
-- half would need an inner @'(Int, Double)@, which has no RTTI, so that case
-- is empty.
putRttiComp :: RTTI Fn '(a,c) -> RttiComp '(a,c) -> Put
putRttiComp RttiFnDD (RttiComp RttiFnDD RttiFnDD) = return ()
putRttiComp RttiFnII (RttiComp RttiFnII RttiFnII) = return ()
putRttiComp RttiFnII (RttiComp RttiFnDI rAB)      = case rAB of {}
putRttiComp RttiFnDI (RttiComp RttiFnII RttiFnDI) = putWord8 0
putRttiComp RttiFnDI (RttiComp RttiFnDI RttiFnDD) = putWord8 1
-- | Inverse of 'putRttiComp': recover the RTTI of both halves of a
-- composition from the RTTI of the composite, reading a tag only when the
-- intermediate type is ambiguous. Fails the parser on an unknown tag.
getRttiComp :: RTTI Fn '(a,c) -> Get (RttiComp '(a,c))
getRttiComp rAC = case rAC of
  RttiFnDD -> pure (RttiComp RttiFnDD RttiFnDD)
  RttiFnII -> pure (RttiComp RttiFnII RttiFnII)
  RttiFnDI -> getWord8 >>= \tag -> case tag of
    0 -> pure (RttiComp RttiFnII RttiFnDI)
    1 -> pure (RttiComp RttiFnDI RttiFnDD)
    _ -> fail "invalid tag"
-- | Serialize a 'Fn' whose RTTI is given.
--
-- Each function starts with a one-byte tag. Tags are unique only within one
-- RTTI index, and 255 always marks a composition. 'Mod' is followed by its
-- modulus. A composition is followed by the 'putRttiComp' data for its
-- intermediate type, then the outer half, then the inner half.
putAct :: RTTI Fn a -> Fn a -> Put
putAct RttiFnDD Exp        = putWord8 0
putAct RttiFnDD Sqrt       = putWord8 1
putAct RttiFnII (Mod m)    = putWord8 0 >> put m
putAct RttiFnDI Round      = putWord8 0
putAct rAC      (Comp g f) = putWord8 255 >> putComp rAC (rtti, g) (rtti, f)
  where
    -- The RTTI of each half comes from the dictionaries stored in 'Comp';
    -- pairing each 'rtti' with its function fixes which instance is meant.
    putComp :: RTTI Fn '(a,c)
            -> (RTTI Fn '(b,c), Fn '(b,c))
            -> (RTTI Fn '(a,b), Fn '(a,b))
            -> Put
    putComp rOuterComposite (rBC, outer) (rAB, inner) = do
      putRttiComp rOuterComposite (RttiComp rBC rAB)
      putAct rBC outer
      putAct rAB inner
-- | Deserialize a 'Fn' whose RTTI is given; the inverse of 'putAct'.
--
-- An unknown tag is reported with 'fail', as 'getRttiComp' already does. It
-- therefore surfaces as an ordinary decoding failure (for example a 'Left'
-- from 'Data.Binary.decodeOrFail'). Previously it was an 'error' thrown out
-- of pure code.
getAct :: RTTI Fn a -> Get (Fn a)
getAct = go
  where
    go :: RTTI Fn a -> Get (Fn a)
    go r@RttiFnDD = do
      tag <- getWord8
      case tag of
        0   -> return Exp
        1   -> return Sqrt
        255 -> goComp r
        _   -> fail "invalid tag"
    go r@RttiFnII = do
      tag <- getWord8
      case tag of
        0   -> Mod <$> get
        255 -> goComp r
        _   -> fail "invalid tag"
    go r@RttiFnDI = do
      tag <- getWord8
      case tag of
        0   -> return Round
        255 -> goComp r
        _   -> fail "invalid tag"

    -- Read a composition. First read the RTTI of its intermediate type, then
    -- the outer and inner halves (the order 'putAct' writes them in). The
    -- recovered RTTI is reflected into the 'HasRTTI' constraints that 'Comp'
    -- requires.
    goComp :: RTTI Fn '(a,c) -> Get (Fn '(a,c))
    goComp rAC = do
      RttiComp rBC rAB <- getRttiComp rAC
      reflectRTTI rBC (reflectRTTI rAB (Comp <$> go rBC <*> go rAB))
-- | Any 'Fn' whose index has statically known RTTI can be (de)serialized.
instance HasRTTI Fn a => Binary (Fn a) where
  get = getAct rtti
  put = putAct rtti
-- | Round-trip a few sample functions, including a nested composition, and
-- print the results.
exampleBinaryAct :: IO ()
exampleBinaryAct = do
  print (roundtrip Exp)
  print (roundtrip Sqrt)
  print (roundtrip (Mod 3))
  print (roundtrip ((Round `Comp` Exp) `Comp` Exp))
{------------------------------------------------------------------------------- | |
Run the examples | |
-------------------------------------------------------------------------------} | |
-- | Run all the examples in order.
main :: IO ()
main =
  sequence_
    [ putStrLn "** Running examples"
    , exampleRoundtripVal
    , exampleEval
    , exampleBinaryAct
    , exampleBinaryLists
    ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment