Last active
December 1, 2024 06:31
-
-
Save aavogt/25851e535cd917c7a9d9cdb2780bfe2e to your computer and use it in GitHub Desktop.
Defun
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE MultiParamTypeClasses #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
{-# LANGUAGE TupleSections #-} | |
{-# HLINT ignore "Functor law" #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE ViewPatterns #-} | |
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} | |
-- | | |
-- | |
-- = Defunctionalization QuasiQuoter | |
-- | |
-- You can't write a name once and have another function apply it at
-- several different types without explicitly writing a type signature.
-- One option is to go through a class like HList's 'ApplyAB'/'Apply' | |
-- which needs the whole type to be written out, and it pollutes the namespace. | |
-- | |
-- Or the dictionaries can be collected / stored in a gadt with something like | |
-- vinyl's 'DictOnly'/'rpureConstrained', and then pattern matching on the | |
-- DictOnly recovers the constraint. | |
-- | |
-- This quasiquote is an attempt to make both processes easier, by calling
-- `hint` to get the type of the expression by itself, and using those inferred | |
-- constraints (cx) as @DictOnly \@cx@ or @instance cx => ApplyAB _ _ _@ | |
-- | |
-- According to mniip, this method assumes @(c1 => a1) ~ (c2 => a2)@ can be | |
-- solved by @c1 ~ c2@ but it doesn't always work. | |
-- | |
-- What is a minimal example where this method fails? Does it break when there is some kind of recursion?
-- | |
-- [defun| failure z = hLength z |] | |
-- let r = hMap (Failure r) $ hBuild () () | |
-- | |
-- == HList | |
-- | |
-- A declaration quasiquote takes a prefix haskell function declaration, | |
-- which is interpreted slightly differently than usual. | |
-- The arguments left of `=` are stored in the data. | |
-- | |
-- > [defun| xPlusK k z = \x -> k + read x / z |]
-- > r :: HList '[Double, Float]
-- > r = hMap (XPlusK 2 3.4) $ hBuild "3.5" "2" | |
-- | |
-- > tests.hs:35:8-40: Splicing declarations | |
-- > Language.Haskell.TH.Quote.quoteDec | |
-- > defun " xPlusK k p = \\x -> read x + k " | |
-- > ======> | |
-- > instance (a_aDWy ~ String, b_aDWz ~ a, Fractional a, Read a) => | |
-- > ApplyAB XPlusK a_aDWy b_aDWz where | |
-- > applyAB (XPlusK k z) = \ x -> k + read x / z | |
-- > data XPlusK | |
-- > = XPlusK (forall {a}. (Fractional a, Read a) => a) (forall {a}. | |
-- > (Fractional a, Read a) => a)
-- | |
-- == Vinyl | |
-- | |
-- As an expression and type for use with vinyl: | |
-- | |
-- > import Data.Vinyl | |
-- > mapRead2 = rzipWith [defun| \(Const x) -> Identity (read x) |] | |
-- > (rpureConstrained @[defun| |] @'[String, Int] (DictOnly @[defun| |])) | |
-- > xs | |
-- > | |
-- > tests.hs:24:28-62: Splicing expression | |
-- > Language.Haskell.TH.Quote.quoteExp | |
-- > defun " \\(Const x) -> Identity (read x) " | |
-- > ======> | |
-- > \ DictOnly -> \ (Const x) -> Identity (read x) | |
-- > tests.hs:25:31-33: Splicing type | |
-- > Language.Haskell.TH.Quote.quoteType defun " " ======> Read | |
-- > tests.hs:25:69-71: Splicing type | |
-- > Language.Haskell.TH.Quote.quoteType defun " " ======> Read | |
-- | |
-- Here the expression quote adds a @\DictOnly ->@, and subsequent type quotes | |
-- list the inferred constraints (Read in this case). | |
-- | |
-- TODO: equivalent of hMap which avoids repetition | |
module Defun where | |
import Control.Lens | |
import Control.Monad | |
import Data.Char (toUpper) | |
import Data.Data.Lens | |
import Data.Either | |
import Data.Foldable | |
import Data.Generics hiding (typeOf) | |
import Data.HList.CommonMain | |
import Data.IORef | |
import Data.List | |
import qualified Data.List.NonEmpty as NE | |
import Data.Maybe | |
import Data.Set (Set) | |
import qualified Data.Set as Set | |
import Data.Vinyl (DictOnly (..)) | |
import GHC.Stack | |
import Language.Haskell.Interpreter as Int | |
import Language.Haskell.Meta | |
import Language.Haskell.TH | |
import Language.Haskell.TH.Quote | |
import qualified Language.Haskell.TH.Syntax as TH | |
import System.IO.Unsafe (unsafePerformIO) | |
import Text.Show.Pretty (pPrint, ppShow) | |
-- | Process-global cell shared between the quoters of 'defun'.
--
-- It starts out as 'Nothing'. The @NOINLINE@ pragma keeps the
-- 'unsafePerformIO' call from being duplicated, so that every reader and
-- writer sees the same 'IORef'.
{-# NOINLINE internalDefunCxt #-}
internalDefunCxt :: IORef (Maybe Type)
internalDefunCxt = unsafePerformIO (newIORef Nothing)
-- | Defunctionalization quasiquote.
--
-- * As an expression, see 'defunExp': the quoted expression is wrapped in
--   @\\DictOnly -> ...@ and its inferred constraints are stored in
--   'internalDefunCxt'.
-- * As a type, the quoted text is ignored and whatever was last stored in
--   'internalDefunCxt' is spliced in, so a type quote only makes sense after
--   an expression quote has run.
-- * As a declaration, see 'defunDec'.
-- * Pattern quotes are not supported.
--
-- Unsupported or premature uses fail inside 'Q' (a regular compile error at
-- the splice) instead of throwing an exception with 'error'.
defun :: QuasiQuoter
defun =
  QuasiQuoter
    { quoteExp = defunExp,
      quoteType = \_ ->
        runIO (readIORef internalDefunCxt)
          >>= maybe (fail "missing defun exp") return,
      quotePat = \_ -> fail "not implemented",
      quoteDec = defunDec
    }
-- | Unwrap a 'Right' inside 'Q'. A 'Left' aborts the splice with @msg@
-- followed by the pretty-printed ('ppShow') error value.
qRight :: (HasCallStack, Show a) => String -> Either a b -> Q b
qRight msg result = case result of
  Left err -> fail (msg ++ ppShow err)
  Right val -> return val
-- | Declaration quoter: turn @f a1 .. an = \\x -> body@ into
--
-- * a data type named after @f@ (first letter capitalised) with one
--   constructor storing @a1 .. an@, each field type made rank-N by 'toRankN';
-- * an @ApplyAB@ instance whose @applyAB@ matches that constructor and runs
--   the original right-hand side.
--
-- The type of @f@ is obtained by running the declaration through @hint@.
-- Fails in 'Q' when the input is empty, is not exactly one declaration, when
-- hint's output cannot be parsed, or when the inferred type has no argument
-- left after the stored ones (i.e. the right-hand side is not a function).
defunDec :: String -> Q [Dec]
defunDec str = do
  -- The function name is the first word of the declaration.
  -- `x+y = \z -> _` will fail, it could generate
  -- data (:+) = (:+) $(toRankN (typeOf x)) $(toRankN (typeOf y))
  fname <- case words str of
    w : _ -> return w
    [] -> fail "defun: empty declaration"
  tyStr <- qRight "cannot infer type with hint:" <=< runIO $ runInterpreter $ do
    -- add standard imports
    Int.set [languageExtensions := [Int.DataKinds, Int.GADTs, Int.QuasiQuotes, Int.TypeApplications, Int.NoMonomorphismRestriction]]
    setImports ["Prelude", "Control.Applicative", "Data.Functor.Identity", "Data.HList.CommonMain"]
    runStmt $ "let " ++ str
    typeOf fname
  decs <- qRight "cannot parse to TH.Dec:" $ parseDecs str
  d <- case decs of
    [single] -> return single
    _ -> fail $ "defun: expected exactly one declaration, got " ++ show (length decs)
  let (fn, nargs) = fundinfo d
      n = capitalizeName fn
  hintTy <- qRight ("hint output: (" ++ tyStr ++ ") cannot be parsed to TH.Type:") $ parseType tyStr
  -- A type without a context does not parse to a ForallT; treat it as
  -- having an empty context instead of failing the pattern match.
  let (cx, tokenArgTypes) = case hintTy of
        ForallT _ c body -> (c, splitApps body)
        body -> ([], splitApps body)
  -- After the stored arguments there must be the argument of applyAB and at
  -- least one result type.
  (argTy, resTys) <- case drop nargs tokenArgTypes of
    a : rest@(_ : _) -> return (a, rest)
    _ ->
      fail $
        "defun: " ++ fname ++ " must return a function after its "
          ++ show nargs
          ++ " stored argument(s); inferred type: "
          ++ tyStr
  -- Start from a dummy declaration, swap in the real constructor (with one
  -- field per stored argument) and then rename every remaining XXX to n.
  tokenDec <-
    [d|data XXX = XXX|]
      <&> template
        %~ ( \NormalC {} ->
               NormalC
                 n
                 [ (Bang NoSourceUnpackedness NoSourceStrictness, toRankN cx ty)
                   | ty <- take nargs tokenArgTypes
                 ]
           )
      <&> template %~ \case
        e | isPrefixOf "XXX" (show e) -> n
        e -> e
  -- The right-hand side of the original declaration becomes the body of applyAB.
  body <- case d ^.. template of
    [b] -> return b
    _ -> fail "didn't parse as a single normalB"
  freshA <- newName "a"
  freshB <- newName "b"
  applyabInst <-
    instanceD
      ( sequence
          ( [t|$(varT freshA) ~ $(return argTy)|]
              : [t|$(varT freshB) ~ $(unsplitApps (map return resTys))|]
              : map return cx
          )
      )
      [t|ApplyAB $(conT n) $(varT freshA) $(varT freshB)|]
      [ funD
          'applyAB
          [ clause
              [conP n (funDpats d)] -- add the XPlusK pattern match
              (return body)
              []
          ]
      ]
  return (reverse (applyabInst : tokenDec))
-- | The argument patterns of a function declaration that has exactly one
-- clause, each lifted into 'Q'. Any other declaration yields no patterns.
funDpats :: Dec -> [Q Pat]
funDpats dec = case dec of
  FunD _ [Clause pats _ _] -> [return p | p <- pats]
  _ -> []
-- | Explicitly quantify the type variables of one stored argument type.
--
-- A type without variables is returned unchanged. Otherwise the result is a
-- @forall@ over every variable of @ty@ plus every variable of the
-- constraints kept from @cx@; a constraint is kept when it mentions one of
-- the variables of @ty@. Variables of @ty@ that no kept constraint mentions
-- are bound as well, so the generated field type never has a free variable.
--
-- When running 'defunDec' the intermediate values look something like:
--
-- > str = "xPlusK k = \\x -> read x + k"
-- > tyStr = (Read a, Num a) => a -> String -> a
-- > cx = (Read a, Num a)
-- > tokenArgTypes = [a, String, a]
--
-- > toRankN cx a = forall a. (Read a, Num a) => a
-- > toRankN cx String = String
toRankN :: Cxt -> Type -> Type
toRankN cx ty
  | Set.null varsInTy = ty
  | otherwise = ForallT [PlainTV v InferredSpec | v <- toList boundVars] cx' ty
  where
    varsInTy = everythingPS varTP ty
    -- NOTE(review): `template` appears to look only at the immediate Type
    -- children of each constraint, so a variable nested deeper (e.g. in
    -- `Show [a]`) would not make the constraint relevant -- confirm.
    relevant = has (template . varTP . filtered (`Set.member` varsInTy))
    cx' = filter relevant cx
    varsInRelCx = everythingPS varTP cx'
    -- Bind the variables of the type itself too, not only those that
    -- happen to occur in a kept constraint.
    boundVars = Set.union varsInTy varsInRelCx
-- | Generic query: walk the whole value and collect, as a 'Set', every
-- focus of the prism @p@ found in any sub-term of type @t@.
everythingPS :: (Typeable t, Ord b) => Prism' t b -> GenericQ (Set b)
everythingPS p = everything Set.union (mkQ Set.empty collect)
  where
    -- A matching node contributes a singleton, anything else nothing.
    collect node = case node ^? p of
      Just hit -> Set.singleton hit
      Nothing -> Set.empty
-- | Prism onto the name of a type variable: it matches 'VarT' and nothing
-- else.
--
-- >>> VarT (mkName "a") ^? varTP
-- Just a
--
-- >>> ConT (mkName "a") ^? varTP
-- Nothing
varTP :: Prism' Type Name
varTP = prism' VarT unwrap
  where
    unwrap (VarT n) = Just n
    unwrap _ = Nothing
-- | Turn @a -> (b -> c) -> d@ into @[a, b -> c, d]@: split a function type
-- at its top-level arrows. A bare 'ArrowT' gives the empty list, any other
-- non-function type a singleton.
--
-- >>> let a_1 = mkName "a1" in splitApps (AppT (AppT ArrowT (VarT a_1)) (AppT (AppT ArrowT (VarT a_1)) (VarT a_1)))
-- [VarT a1,VarT a1,VarT a1]
--
-- >>> let a = mkName "a"; b = mkName "b" in splitApps (AppT (AppT ArrowT (VarT a)) (VarT b))
-- [VarT a,VarT b]
--
-- >>> let [a,b,c,d] = map mkName (words "a b c d") in splitApps (AppT (AppT ArrowT (VarT a)) (AppT (AppT ArrowT (VarT b)) (AppT (AppT ArrowT (VarT c)) (VarT d))))
-- [VarT a,VarT b,VarT c,VarT d]
--
-- >>> let [a,b,c] = map mkName (words "a b c") in splitApps (AppT (AppT ArrowT (AppT (AppT ArrowT (VarT a)) (VarT b))) (VarT c))
-- [AppT (AppT ArrowT (VarT a)) (VarT b),VarT c]
splitApps :: Type -> [Type]
splitApps ty = case ty of
  AppT (AppT ArrowT dom) cod -> dom : splitApps cod
  ArrowT -> []
  other -> [other]
-- | Inverse of 'splitApps': fold @[a, b, c]@ back into @a -> b -> c@.
--
-- A singleton list gives that type unchanged. An empty list has no
-- corresponding type, so it fails in 'Q' with a descriptive message
-- instead of crashing inside @foldr1@.
unsplitApps :: [TypeQ] -> TypeQ
unsplitApps [] = fail "unsplitApps: cannot build a function type from an empty list of types"
unsplitApps ts = foldr1 (\a b -> [t|$a -> $b|]) ts
-- | The constraints of a context that mention the type variable @n@.
--
-- The previous test, @VarT n `elem` splitApps x@, could not work:
-- 'splitApps' only splits at function arrows, so a constraint such as
-- @Read a@ came back as a single element and never compared equal to the
-- bare variable. The variables are now collected from the whole constraint.
lookupCxt :: Cxt -> Name -> [Type]
lookupCxt xs n = filter mentions xs
  where
    mentions c = n `Set.member` everythingPS varTP c
-- | Name of the declared function together with the number of argument
-- patterns in its first clause, i.e. the number of values to store in the
-- generated constructor (2 for @xPlusK k z = ...@). A plain variable
-- binding counts as zero arguments; anything else is an error.
fundinfo :: Dec -> (Name, Int)
fundinfo dec = case dec of
  FunD name (Clause pats _ _ : _) -> (name, length pats)
  ValD (VarP name) _ _ -> (name, 0)
  _ -> error "expected FunD with one clause"
-- | Upper-case the first character of the name as rendered by 'show', so
-- that a function name can be reused as a type / constructor name.
capitalizeName :: Name -> Name
capitalizeName = mkName . upcaseFirst . show
  where
    upcaseFirst [] = []
    upcaseFirst (c : cs) = toUpper c : cs
-- | Expression quoter.
--
-- The quoted text is bound to a name in a @hint@ session to infer its type.
-- The context of that type, reduced by 'splitContext', is stored in
-- 'internalDefunCxt', and the parsed expression is returned wrapped as
-- @\\DictOnly -> expr@. Any failure of the interpreter or of either parser
-- aborts the splice.
defunExp :: String -> Q Exp
defunExp str = do
  inferred <- runIO $ runInterpreter $ do
    -- add standard imports
    Int.set [languageExtensions := [Int.DataKinds, Int.GADTs, Int.QuasiQuotes, Int.TypeApplications, Int.NoMonomorphismRestriction]]
    setImports ["Prelude", "Control.Applicative", "Data.Functor.Identity"]
    runStmt $ "let defunExprResult = " ++ str
    typeOf "defunExprResult"
  tyStr <- case inferred of
    Left err -> fail $ show ("ran interpreter", show err)
    Right s -> return s
  case parseType (dropForall tyStr) of
    Left err -> fail $ show ("parsetype", tyStr, err)
    Right ty -> do
      -- The context is published before the expression itself is parsed.
      runIO $ writeIORef internalDefunCxt $ Just $ splitContext ty
      case parseExp str of
        Left err -> fail $ show ("parseExp", str, err)
        Right e -> [|\DictOnly -> $(return e)|]
-- | Strip a leading @forall ... .@ from a type rendered as text: when the
-- string starts with @forall@, everything up to and including the first
-- @.@ is dropped. Any other string is returned as is.
dropForall :: String -> String
dropForall s = case stripPrefix "forall" s of
  Just rest -> drop 1 (dropWhile (/= '.') rest)
  Nothing -> s
-- | Reduce a qualified type to its context: every constraint except those
-- rejected by 'notFunctor' loses its last argument ('dropArg') and the
-- results are tupled with 'tupT_'. A type that is not a 'ForallT' is
-- returned unchanged.
splitContext :: Type -> Type
splitContext ty = case ty of
  ForallT _ cxt _ -> tupT_ [dropArg c | c <- cxt, notFunctor c]
  _ -> ty
-- | 'False' only for a constructor named exactly @Functor@ applied to one
-- argument; every other type passes.
notFunctor :: Type -> Bool
notFunctor ty = case ty of
  AppT (ConT cls) _ -> show cls /= "Functor"
  _ -> True
-- | Tuple a list of types in 'Q': the empty list gives @()@, a single type
-- is returned as is, and longer lists become the n-ary tuple type.
tupT :: [TypeQ] -> TypeQ
tupT ts = case ts of
  [] -> [t|()|]
  [single] -> single
  _ -> foldl appT (tupleT (length ts)) ts
-- | Pure counterpart of 'tupT': the empty list gives the unit type
-- constructor, a single type is returned as is, and longer lists become
-- the n-ary tuple type applied to the elements.
tupT_ :: [Type] -> Type
tupT_ ts = case ts of
  [] -> ConT ''()
  [single] -> single
  _ -> foldl AppT (TupleT (length ts)) ts
-- | Remove the last argument of a type application (@Read a@ becomes
-- @Read@); a type that is not an application is returned unchanged.
dropArg :: Type -> Type
dropArg ty = case ty of
  AppT f _ -> f
  _ -> ty
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
dependencies: | |
- template-haskell | |
- vinyl | |
- hint | |
- base | |
- haskell-src-meta | |
- HList | |
- pretty-show | |
- lens | |
- containers | |
- syb | |
library: | |
source-dirs: | |
- . | |
tests: | |
tests: | |
source-dirs: | |
- . | |
main: | |
tests.hs | |
dependencies: | |
- defun |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE PartialTypeSignatures #-} | |
{-# LANGUAGE QuasiQuotes #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
{-# OPTIONS_GHC -ddump-splices #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE MultiParamTypeClasses #-} | |
import Control.Applicative | |
import Data.Functor.Identity | |
import Data.HList.CommonMain | |
import Data.Vinyl hiding (HList) | |
import Defun | |
-- Input record: two 'Const' fields whose payloads are the strings that the
-- mapRead* examples parse with 'read'.
xs = Const "\"x\"" :& Const "3" :& RNil | |
-- A record of 'Read' dictionaries, one per field type (String and Int).
rpc = rpureConstrained @Read @'[String, Int] (DictOnly @Read) | |
-- Hand-written version of mapRead2: zip the dictionary record with the
-- input record, matching on DictOnly to bring each field's Read instance
-- into scope before parsing the string.
mapRead =
  rzipWith (\DictOnly -> \(Const raw) -> Identity (read raw)) rpc xs
-- Same as mapRead, written with the quasi-quoter. The expression quote wraps
-- the lambda in a DictOnly match and records the inferred constraints; the
-- two empty type quotes ignore their text and splice in whatever was
-- recorded, so they rely on the expression quote having been expanded first.
mapRead2 = rzipWith [defun| \(Const x) -> Identity (read x) |] | |
(rpureConstrained @[defun| |] @'[String, Int] (DictOnly @[defun| |])) | |
xs | |
-- A top-level declaration quasi-quote ends the current declaration group:
-- definitions above it cannot refer to anything defined below it. 'main'
-- uses 'mapRead3', so it has to come after the quote rather than before it.
[defun| xPlusK k z = \x -> k + read x / z |]

-- Apply the generated XPlusK token to every element of an HList.
mapRead3 :: HList '[Double, Float]
mapRead3 = hMap (XPlusK 2 3.4) $ hBuild "3.5" "2"

-- Print the result of each of the three variants.
main :: IO ()
main = do
  print mapRead
  print mapRead2
  print mapRead3
-- [defun| recurse e = \f -> hLength e |] | |
-- data Recurse | |
-- = Recurse (forall {l} {n}. | |
-- SameLength' (HReplicateR n ()) l => HList l) | |
-- instance (a_aHwS ~ p, | |
-- b_aHwT ~ Proxy n, | |
-- SameLength' (HReplicateR n ()) l, | |
-- HLengthEq1 l n, | |
-- HLengthEq2 l n) => | |
-- ApplyAB Recurse a_aHwS b_aHwT where | |
-- applyAB (Recurse e) = \ f -> hLength e | |
-- tests.hs:44:8-38: error: [GHC-39999] | |
-- • Could not deduce ‘SameLength' (HReplicateR n ()) l0’ | |
-- from the context: (a ~ p, b ~ Proxy n, | |
-- SameLength' (HReplicateR n ()) l, HLengthEq1 l n, HLengthEq2 l n) | |
-- | |
-- this doesn't typecheck. How can it? |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment