Skip to content

Instantly share code, notes, and snippets.

@aavogt
Last active December 1, 2024 06:31
Show Gist options
  • Save aavogt/25851e535cd917c7a9d9cdb2780bfe2e to your computer and use it in GitHub Desktop.
Save aavogt/25851e535cd917c7a9d9cdb2780bfe2e to your computer and use it in GitHub Desktop.
Defun
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# HLINT ignore "Functor law" #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
-- |
--
-- = Defunctionalization QuasiQuoter
--
-- You can't write a name once and have another function apply it at
-- different types without explicitly writing a type signature.
-- One option is to go through a class like HList's 'ApplyAB'/'Apply'
-- which needs the whole type to be written out, and it pollutes the namespace.
--
-- Alternatively, the dictionaries can be collected and stored in a GADT with
-- something like vinyl's 'DictOnly'/'rpureConstrained'; pattern matching on
-- the DictOnly then recovers the constraint.
--
-- This quasiquote is an attempt to make both approaches easier: it calls
-- `hint` to get the type of the expression by itself, and uses the inferred
-- constraints (cx) as @DictOnly \@cx@ or @instance cx => ApplyAB _ _ _@.
--
-- According to mniip, this method assumes @(c1 => a1) ~ (c2 => a2)@ can be
-- solved by @c1 ~ c2@ but it doesn't always work.
--
-- What is a minimal example where this method fails? Perhaps it does not work
-- when there is some kind of recursion:
--
-- > [defun| failure z = hLength z |]
-- > let r = hMap (Failure r) $ hBuild () ()
--
-- == HList
--
-- A declaration quasiquote takes a prefix haskell function declaration,
-- which is interpreted slightly differently than usual.
-- The arguments left of `=` are stored in the data.
--
-- > [defun| xPlusK k z = \x -> k + read x / z |]
-- > r :: HList '[Double, Float]
-- > r = hMap (XPlusK 2 3.4) $ hBuild "3.5" "2"
--
-- > tests.hs:35:8-40: Splicing declarations
-- > Language.Haskell.TH.Quote.quoteDec
-- > defun " xPlusK k z = \\x -> k + read x / z "
-- > ======>
-- > instance (a_aDWy ~ String, b_aDWz ~ a, Fractional a, Read a) =>
-- > ApplyAB XPlusK a_aDWy b_aDWz where
-- > applyAB (XPlusK k z) = \ x -> k + read x / z
-- > data XPlusK
-- > = XPlusK (forall {a}. (Fractional a, Read a) => a) (forall {a}.
-- > (Fractional a, Read a) => a)
--
-- == Vinyl
--
-- As an expression and type for use with vinyl:
--
-- > import Data.Vinyl
-- > mapRead2 = rzipWith [defun| \(Const x) -> Identity (read x) |]
-- > (rpureConstrained @[defun| |] @'[String, Int] (DictOnly @[defun| |]))
-- > xs
-- >
-- > tests.hs:24:28-62: Splicing expression
-- > Language.Haskell.TH.Quote.quoteExp
-- > defun " \\(Const x) -> Identity (read x) "
-- > ======>
-- > \ DictOnly -> \ (Const x) -> Identity (read x)
-- > tests.hs:25:31-33: Splicing type
-- > Language.Haskell.TH.Quote.quoteType defun " " ======> Read
-- > tests.hs:25:69-71: Splicing type
-- > Language.Haskell.TH.Quote.quoteType defun " " ======> Read
--
-- Here the expression quote prepends a @\\DictOnly ->@ lambda, and the
-- subsequent type quotes splice in the inferred constraints (Read in this case).
--
-- TODO: equivalent of hMap which avoids repetition
module Defun where
import Control.Lens
import Control.Monad
import Data.Char (toUpper)
import Data.Data.Lens
import Data.Either
import Data.Foldable
import Data.Generics hiding (typeOf)
import Data.HList.CommonMain
import Data.IORef
import Data.List
import qualified Data.List.NonEmpty as NE
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Vinyl (DictOnly (..))
import GHC.Stack
import Language.Haskell.Interpreter as Int
import Language.Haskell.Meta
import Language.Haskell.TH
import Language.Haskell.TH.Quote
import qualified Language.Haskell.TH.Syntax as TH
import System.IO.Unsafe (unsafePerformIO)
import Text.Show.Pretty (pPrint, ppShow)
-- | Process-global slot through which an expression quote ('defunExp') hands
-- its inferred constraint to the @[defun| |]@ type quotes that read it later.
--
-- Holds 'Nothing' until the first expression quote has stored a value. The
-- NOINLINE pragma keeps GHC from inlining/duplicating the 'unsafePerformIO'
-- call, which could otherwise create several independent IORefs.
internalDefunCxt :: IORef (Maybe Type)
{-# NOINLINE internalDefunCxt #-}
internalDefunCxt = unsafePerformIO $ newIORef Nothing
-- | Defunctionalization quasiquote.
--
-- * Declarations ('defunDec'): generate a token data type and an 'ApplyAB'
--   instance from a prefix function declaration.
-- * Expressions ('defunExp'): wrap the expression in @\\DictOnly -> ...@ and
--   record its inferred constraint in 'internalDefunCxt'.
-- * Types: splice the constraint most recently recorded by an expression
--   quote; the splice fails if no expression quote has run yet.
--   NOTE(review): this relies on GHC expanding the expression splice before
--   the type splices that read its result.
-- * Patterns: not supported; using it fails the splice.
defun :: QuasiQuoter
defun =
  QuasiQuoter
    { quoteExp = defunExp,
      -- fail inside Q (a proper compile error) rather than a lazy 'error'
      quoteType = \_ ->
        runIO (readIORef internalDefunCxt)
          >>= maybe (fail "missing defun exp") return,
      quotePat = \_ -> fail "defun: quotePat not implemented",
      quoteDec = defunDec
    }
-- | Move an 'Either' into 'Q'. A 'Right' value is returned as is. A 'Left'
-- aborts the splice with @msg@ followed by the pretty-printed error.
qRight :: (HasCallStack, Show a) => String -> Either a b -> Q b
qRight msg result = case result of
  Left err -> fail (msg ++ ppShow err)
  Right val -> return val
-- | Declaration quasiquote: defunctionalize a prefix function declaration.
--
-- For @[defun| xPlusK k z = \\x -> k + read x / z |]@ this generates
--
-- * @data XPlusK = XPlusK t1 t2@, with one field per argument left of @=@.
--   Each @ti@ is that argument's inferred type, made rank-N by 'toRankN'.
-- * @instance (a ~ String, b ~ r, cx) => ApplyAB XPlusK a b@, whose
--   'applyAB' matches @XPlusK k z@ and returns the right-hand side.
--
-- The declaration is first typed with @hint@, and its inferred constraints
-- (@cx@) become the instance context. The right-hand side must itself be a
-- function, because 'applyAB' consumes one more argument. Otherwise the
-- splice fails with an explanatory message.
defunDec :: String -> Q [Dec]
defunDec str = do
  -- hint is asked for the type of the first word, i.e. the declared name.
  fname <- case words str of
    w : _ -> return w
    [] -> fail "defun: empty declaration"
  tyStr <- qRight "cannot infer type with hint:" <=< runIO $ runInterpreter $ do
    -- add standard imports
    Int.set [languageExtensions := [Int.DataKinds, Int.GADTs, Int.QuasiQuotes, Int.TypeApplications, Int.NoMonomorphismRestriction]]
    setImports ["Prelude", "Control.Applicative", "Data.Functor.Identity", "Data.HList.CommonMain"]
    runStmt $ "let " ++ str
    -- `x+y = \z -> _` will fail, it could generate data (:+) = (:+) $(toRankN (typeOf x)) $(toRankN (typeOf y))
    typeOf fname
  d <-
    qRight "cannot parse to TH.Dec:" (parseDecs str) >>= \case
      [dec] -> return dec
      decs -> fail ("defun: expected exactly one declaration, got " ++ show (length decs))
  let (capitalizeName -> n, nargs) = fundinfo d
  parsedTy <- qRight ("hint output: (" ++ tyStr ++ ") cannot be parsed to TH.Type:") $ parseType tyStr
  -- A type without constraints (e.g. "Int -> String -> Int") is not wrapped
  -- in ForallT, so treat it as having an empty context.
  let (cx, argTypes) = case parsedTy of
        ForallT _ c t -> (c, splitApps t)
        t -> ([], splitApps t)
  -- storedTys: the constructor fields; inTy: the argument applyAB receives;
  -- outTys: the rest of the arrow chain, which forms applyAB's result type.
  (storedTys, inTy, outTys) <- case splitAt nargs argTypes of
    (stored, i : outs@(_ : _)) -> return (stored, i, outs)
    _ -> fail ("defun: the right-hand side of " ++ fname ++ " must be a function, but hint inferred " ++ tyStr)
  let field t = (Bang NoSourceUnpackedness NoSourceStrictness, toRankN cx t)
      fillCon = \case
        NormalC {} -> NormalC n (map field storedTys)
        c -> c
      -- names bound by the [d| |] template below print as "XXX..."
      renameXXX nm
        | "XXX" `isPrefixOf` show nm = n
        | otherwise = nm
  tokenDec <- [d|data XXX = XXX|] <&> template %~ fillCon <&> template %~ renameXXX
  body <- case (d ^.. template :: [Body]) of
    [b] -> return b
    _ -> fail "didn't parse as a single normalB"
  freshA <- newName "a"
  freshB <- newName "b"
  applyabInst <-
    instanceD
      ( sequence
          ( [t|$(varT freshA) ~ $(return inTy)|]
              : [t|$(varT freshB) ~ $(unsplitApps (map return outTys))|]
              : map return cx
          )
      )
      [t|ApplyAB $(conT n) $(varT freshA) $(varT freshB)|]
      [ funD
          'applyAB
          [ clause
              [conP n (funDpats d)] -- add the XPlusK pattern match
              (return body)
              []
          ]
      ]
  return (reverse (applyabInst : tokenDec))
-- | Argument patterns of a single-clause function declaration, each lifted
-- into 'Q'. Any other declaration (including multi-clause functions and
-- plain value bindings) yields no patterns.
funDpats :: Dec -> [Q Pat]
funDpats dec = case dec of
  FunD _ [Clause pats _ _] -> fmap pure pats
  _ -> []
-- | Explicitly quantify the type variables of one argument type.
--
-- When running 'defunDec' on a declaration, these are the intermediate values:
--
-- > str = "xPlusK k = \\x -> read x + k"
-- > tyStr = (Num a, Read a) => a -> String -> a
-- > cx = (Num a, Read a)
-- > tokenArgTypes = [a, String, a]
--
-- > toRankN cx a = forall {a}. (Num a, Read a) => a
-- > toRankN cx String = String
--
-- Only the constraints that mention a variable of @ty@ are kept. Every
-- variable of @ty@ is bound, including one that no constraint mentions
-- (@toRankN [] a = forall {a}. a@), so the result never leaves a type
-- variable of @ty@ free. Variables that appear only in the kept constraints
-- are bound as well.
toRankN :: Cxt -> Type -> Type
toRankN cx ty
  | Set.null varsInTy = ty
  | otherwise = ForallT [PlainTV v InferredSpec | v <- toList boundVars] cx' ty
  where
    varsInTy = everythingPS varTP ty
    relevant = has (template . varTP . filtered (`Set.member` varsInTy))
    cx' = filter relevant cx
    boundVars = varsInTy `Set.union` everythingPS varTP cx'
-- | Collect, as a set, every value that prism @p@ matches anywhere inside a
-- generic structure. For example, @everythingPS varTP@ gives all type
-- variable names occurring in a 'Type'.
everythingPS :: (Typeable t, Ord b) => Prism' t b -> GenericQ (Set b)
everythingPS p = everything Set.union (mkQ Set.empty (foldMap Set.singleton . preview p))
-- | Prism between a type variable 'Type' ('VarT') and its 'Name'.
--
-- >>> VarT (mkName "a") ^? varTP
-- Just a
--
-- >>> ConT (mkName "a") ^? varTP
-- Nothing
varTP :: Prism' Type Name
varTP = prism' VarT fromVarT
  where
    fromVarT (VarT name) = Just name
    fromVarT _ = Nothing
-- | Turn @a -> (b -> c) -> d@ into @[a, b -> c, d]@.
--
-- This splits the top-level arrow chain of a type into the argument types
-- followed by the final result. Arrows nested inside an argument are kept,
-- a type that is not an arrow becomes a singleton list, and a bare @(->)@
-- gives @[]@.
--
-- >>> let a_1 = mkName "a1" in splitApps (AppT (AppT ArrowT (VarT a_1)) (AppT (AppT ArrowT (VarT a_1)) (VarT a_1)))
-- [VarT a1,VarT a1,VarT a1]
--
-- >>> let a = mkName "a"; b = mkName "b" in splitApps (AppT (AppT ArrowT (VarT a)) (VarT b))
-- [VarT a,VarT b]
--
-- >>> let [a,b,c,d] = map mkName (words "a b c d") in splitApps (AppT (AppT ArrowT (VarT a)) (AppT (AppT ArrowT (VarT b)) (AppT (AppT ArrowT (VarT c)) (VarT d))))
-- [VarT a,VarT b,VarT c,VarT d]
--
-- >>> let [a,b,c] = map mkName (words "a b c") in splitApps (AppT (AppT ArrowT (AppT (AppT ArrowT (VarT a)) (VarT b))) (VarT c))
-- [AppT (AppT ArrowT (VarT a)) (VarT b),VarT c]
splitApps :: Type -> [Type]
splitApps (AppT (AppT ArrowT t1) t2) = t1 : splitApps t2
splitApps ArrowT = []
splitApps x = [x]
-- | Join quoted types with right-nested arrows, the inverse of 'splitApps'.
-- @[a, b, c]@ becomes @a -> b -> c@, and a singleton is returned unchanged.
-- An empty list fails the splice with a message instead of crashing in
-- 'foldr1'.
unsplitApps :: [TypeQ] -> TypeQ
unsplitApps [] = fail "unsplitApps: empty list of types"
unsplitApps ts = foldr1 (\a b -> [t|$a -> $b|]) ts
-- | The constraints in @xs@ that mention the type variable @n@ anywhere.
--
-- For @xs = [Read a, Show b, C a b]@, @lookupCxt xs a@ is @[Read a, C a b]@.
-- (Splitting on arrows, as before, never looked inside a class application,
-- so @Read a@ was not found.)
lookupCxt :: Cxt -> Name -> [Type]
lookupCxt xs n = [x | x <- xs, everything (||) (mkQ False (== VarT n)) x]
-- | Name and number of stored arguments of a declaration given to 'defunDec'.
--
-- For @xPlusK k z = \\x -> ...@ this is @(xPlusK, 2)@. The arguments left of
-- @=@ become the fields of the generated @XPlusK@ constructor. A plain
-- variable binding @f = ...@ stores nothing. Multi-clause functions are
-- rejected here, because 'funDpats' and the single-body check in 'defunDec'
-- only handle one clause.
fundinfo :: Dec -> (Name, Int)
fundinfo (FunD n [Clause ps _ _]) = (n, length ps)
fundinfo (ValD (VarP x) _ _) = (x, 0)
fundinfo _ = error "fundinfo: expected a FunD with one clause or a simple variable binding"
-- | Name of the generated token type/constructor: the declared name with its
-- first character upper-cased, e.g. @xPlusK@ becomes @XPlusK@.
--
-- Uses 'nameBase' rather than 'show', so a qualified or uniquified 'Name'
-- does not leak its module prefix or unique suffix into the new identifier.
capitalizeName :: Name -> Name
capitalizeName n = mkName $ case nameBase n of
  c : cs -> toUpper c : cs
  [] -> []
-- | Expression quasiquote. The steps are:
--
-- 1. Infer the expression's type with @hint@.
-- 2. Store its constraint (see 'splitContext') in 'internalDefunCxt' for the
--    @[defun| |]@ type quotes that follow.
-- 3. Return the expression wrapped in @\\DictOnly -> ...@, so matching on the
--    dictionary brings that constraint into scope.
--
-- Any interpreter, type-parse or expression-parse failure aborts the splice.
defunExp :: String -> Q Exp
defunExp str = do
  inferred <- runIO $ runInterpreter $ do
    -- add standard imports
    Int.set [languageExtensions := [Int.DataKinds, Int.GADTs, Int.QuasiQuotes, Int.TypeApplications, Int.NoMonomorphismRestriction]]
    setImports ["Prelude", "Control.Applicative", "Data.Functor.Identity"]
    runStmt $ "let defunExprResult = " ++ str
    typeOf "defunExprResult"
  tyStr <- case inferred of
    Left err -> fail $ show ("ran interpreter", show err)
    Right s -> return s
  ty <- case parseType (dropForall tyStr) of
    Left err -> fail $ show ("parsetype", tyStr, err)
    Right t -> return t
  -- NOTE(review): later type quotes read this global, so correctness depends
  -- on GHC expanding this splice before them.
  runIO $ writeIORef internalDefunCxt $ Just $ splitContext ty
  case parseExp str of
    Left err -> fail $ show ("parseExp", str, err)
    Right e -> [|\DictOnly -> $(return e)|]
-- | Strip a leading @forall ... .@ quantifier from a printed type. Everything
-- up to and including the first @.@ is removed, leaving the leading space of
-- the remainder. A string that does not start with @forall@ is returned
-- unchanged.
dropForall :: String -> String
dropForall s = case stripPrefix "forall" s of
  Just rest -> drop 1 (dropWhile (/= '.') rest)
  Nothing -> s
-- | The constraint to splice for a @[defun| |]@ type quote.
--
-- The classes in the context of @ty@ lose their last argument (@Read a@
-- becomes @Read@), @Functor@ constraints are skipped, and the rest are
-- combined with 'tupT_'. A type with no 'ForallT' context is returned
-- unchanged.
--
-- NOTE(review): 'tupT_' builds an ordinary tuple, so with zero or several
-- remaining classes the result is presumably not usable as a
-- @k -> Constraint@ argument such as @DictOnly \@c@ — confirm.
splitContext :: Type -> Type
splitContext ty = case ty of
  ForallT _ preds _ -> tupT_ [dropArg p | p <- preds, notFunctor p]
  other -> other
-- | 'False' for a @Functor f@ constraint and 'True' for every other
-- constraint. It compares 'nameBase', so a qualified @GHC.Base.Functor@ is
-- also recognised.
notFunctor :: Type -> Bool
notFunctor (ConT f `AppT` _) = nameBase f /= "Functor"
notFunctor _ = True
-- | Build a tuple type from quoted components. No components give @()@, one
-- component is returned as is, and more give an n-ary tuple application.
tupT :: [TypeQ] -> TypeQ
tupT components = case components of
  [] -> [t|()|]
  [single] -> single
  _ -> foldl appT (tupleT arity) components
  where
    arity = length components
-- | Pure counterpart of 'tupT'. No components give the unit type
-- constructor, one component is returned as is, and more give an n-ary
-- 'TupleT' application.
tupT_ :: [Type] -> Type
tupT_ components = case components of
  [] -> ConT ''()
  [single] -> single
  _ -> foldl AppT (TupleT arity) components
  where
    arity = length components
-- | Remove the last argument of a type application, so @Read a@ becomes
-- @Read@. A type that is not an application is returned unchanged.
dropArg :: Type -> Type
dropArg ty = case ty of
  AppT f _ -> f
  _ -> ty
dependencies:
- template-haskell
- vinyl
- hint
- base
- haskell-src-meta
- HList
- pretty-show
- lens
- containers
- syb
library:
source-dirs:
- .
tests:
tests:
source-dirs:
- .
main:
tests.hs
dependencies:
- defun
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -ddump-splices #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
import Control.Applicative
import Data.Functor.Identity
import Data.HList.CommonMain
import Data.Vinyl hiding (HList)
import Defun
-- | Raw inputs for the 'read' tests: a quoted String literal and an Int, both
-- still as text.
xs = rawString :& rawInt :& RNil
  where
    rawString = Const "\"x\""
    rawInt = Const "3"
-- | A 'DictOnly' 'Read' dictionary for each of String and Int, built with
-- vinyl's 'rpureConstrained'.
rpc = rpureConstrained @Read @'[String, Int] (DictOnly @Read)
-- | Parse every raw string in 'xs' with 'read', using the 'Read' dictionaries
-- packed into 'rpc'.
mapRead = rzipWith readWithDict rpc xs
  where
    -- Matching on DictOnly brings the Read dictionary for @a@ into scope.
    readWithDict :: DictOnly Read a -> Const String a -> Identity a
    readWithDict DictOnly (Const raw) = Identity (read raw)
-- | The same computation as 'mapRead', but written with 'defun'. The
-- expression quote infers the constraint (Read), and the two empty type
-- quotes splice it back in.
-- NOTE(review): the type quotes read what the expression quote recorded, so
-- this depends on GHC expanding the expression splice first; keep the
-- quotes in this order.
mapRead2 = rzipWith [defun| \(Const x) -> Identity (read x) |]
(rpureConstrained @[defun| |] @'[String, Int] (DictOnly @[defun| |]))
xs
-- The declaration splice below starts a new declaration group. Names it
-- generates (XPlusK), and anything defined after it (mapRead3), are not in
-- scope in earlier groups, so 'main' must come after the splice.
[defun| xPlusK k z = \x -> k + read x / z |]

mapRead3 :: HList '[Double, Float]
mapRead3 = hMap (XPlusK 2 3.4) $ hBuild "3.5" "2"

-- | Print the results of the three approaches.
main :: IO ()
main = do
  print mapRead
  print mapRead2
  print mapRead3
-- [defun| recurse e = \f -> hLength e |]
-- data Recurse
-- = Recurse (forall {l} {n}.
-- SameLength' (HReplicateR n ()) l => HList l)
-- instance (a_aHwS ~ p,
-- b_aHwT ~ Proxy n,
-- SameLength' (HReplicateR n ()) l,
-- HLengthEq1 l n,
-- HLengthEq2 l n) =>
-- ApplyAB Recurse a_aHwS b_aHwT where
-- applyAB (Recurse e) = \ f -> hLength e
-- tests.hs:44:8-38: error: [GHC-39999]
-- • Could not deduce ‘SameLength' (HReplicateR n ()) l0’
-- from the context: (a ~ p, b ~ Proxy n,
-- SameLength' (HReplicateR n ()) l, HLengthEq1 l n, HLengthEq2 l n)
--
-- this doesn't typecheck. How can it?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment