Skip to content

Instantly share code, notes, and snippets.

@charles-cooper
Last active December 10, 2015 22:17
Show Gist options
  • Save charles-cooper/379d1702df2051b9c60c to your computer and use it in GitHub Desktop.
Save charles-cooper/379d1702df2051b9c60c to your computer and use it in GitHub Desktop.
RFC: dsl for writing C
import Language.C.DSL
import Language.Format
-- | Demo program: assembles a small C translation unit with the DSL and
-- prints the rendered source to stdout.
--
-- NOTE(review): 'creturn' is used below but is not defined in the
-- Language.C.DSL listing shown alongside this example -- confirm it exists.
main :: IO ()
main = prettyPrint program
  where
    program = do
      include "stdio.h"
      includeLocal "foo.h"
      intptr_t <- typedef (ptr int) "intptr_t"
      x <- int "x"
      y <- double "y"
      z <- ptr intptr_t "z"
      -- these three are declared only to show how they render
      _w <- carray int "w" 21
      _ptrToConstInt <- ptr (cconst int) "a"
      _constPtrToInt <- cconst (ptr int) "b"
      f <- fun int "f" [int "x1"] $ do
        cswitch y $ do
          ccaseBreak (boolCompl x) $
            cif (bitCompl "x1") $
              cwhile (insertComment "Lorem ipsum" (1 :: Int) !=: postincr x) $
                x +: "x1"
          cdefaultCase $
            x +: z
        creturn $ x +: "x1"
      block $ do
        x' <- int "x"
        boolCompl x'
      x +: y
      x =: y +: f [x, insertComment "Lorem ipsum dolor sit amet" y]
      swap [x, y, z]
#include <stdio.h>
#include "foo.h"
typedef int* intptr_t;
int x;
double y;
intptr_t* z;
int w[21];
int const* a;
int* const b;
int f(int x1) {
switch (y) {
case (!x) : {
if (~x1) {
while (1/*Lorem ipsum*/ != x++) {
x + x1;
}
}
break;
}
default : {
x + z;
}
}
return x + x1
}
{
int x;
!x;
}
x + y;
x = y + f(x, y/*Lorem ipsum dolor sit amet*/);
int tmp;
tmp = x;
x = y;
y = z;
z = tmp;
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DeriveFunctor #-}
-- {-# LANGUAGE UndecidableInstances #-}
-- {-# LANGUAGE IncoherentInstances #-}
module Language.C.DSL where
import Control.Monad.Free
import Control.Exception.Base (bracket)
import Language.Format
-- import System.Console.ANSI
import System.IO (hPutStr, hPutStrLn, stderr)
import System.Process (readProcess)
-- | Says whether a 'CStructUnion' type is a struct or a union.
data CStructUnionTag = CStructTag | CUnionTag
  deriving (Eq, Show)
-- | The C types the DSL can name.
data Type
  -- integral types
  = CChar | CSChar | CUChar
  | CShort | CUShort
  | CInt | CUInt
  | CLong | CULong
  | CLLong | CULLong
  -- floating-point types
  | CFloat | CDouble | CLDouble
  | CVoid
  -- type constructors wrapping another type
  | CPtr Type
  | CConst Type
  | CRef Type
  -- volatile and restrict are not modelled for now
  | OpaqueType Iden
  -- the following two need to be declared with TypeDecl
  | CStructUnion CStructUnionTag Iden [(Type, Iden)]
  | TypeDef Type Iden
  deriving (Eq, Show)
-- | Shape of a declarator: given a name it produces a statement program
-- whose result is an expression.
type AddType = Iden -> Free Stmt Expr

-- | An identifier together with its type.
data TypedIden = TypedIden
  { extractType :: Type
  , extractIden :: Iden
  } deriving Show

-- | A variable, i.e. a wrapped 'TypedIden'.
newtype Var = Var { unVar :: TypedIden }
  deriving Show
-- | Spells a type as source text.  Qualifiers and pointer/reference markers
-- are written after the type they apply to; struct/union members and the
-- aliased type of a typedef are not printed here, only the name.
instance Format Type where
  fmt ty = case ty of
    CChar    -> "char"
    CSChar   -> "signed char"
    CUChar   -> "unsigned char"
    CShort   -> "short"
    CUShort  -> "unsigned short"
    CInt     -> "int"
    CUInt    -> "unsigned int"
    CLong    -> "long"
    CULong   -> "unsigned long"
    CLLong   -> "long long"
    CULLong  -> "unsigned long long"
    CFloat   -> "float"
    CDouble  -> "double"
    CLDouble -> "long double"
    CVoid    -> "void"
    CPtr inner   -> fmt inner ++ "*"
    CRef inner   -> fmt inner ++ "&"
    CConst inner -> fmt inner ++ " const"
    CStructUnion CStructTag name _ -> "struct " ++ name
    CStructUnion CUnionTag name _  -> "union " ++ name
    TypeDef _ name  -> name
    OpaqueType name -> name
-- | Identifiers are plain strings.
type Iden = String

-- | A function definition.
data Fun = Fun
  TypedIden       -- return type paired with the function's name
  [Var]           -- args
  (Free Stmt ())  -- body
  deriving Show
-- | Peel the outermost statement layer off a free-monad program.
--
-- Only meaningful for a program that starts with a statement.  A program
-- with no statement layer has nothing to return, so that case aborts with a
-- descriptive message rather than an anonymous pattern-match failure.
unFree :: Free f a -> f (Free f a)
unFree (Free layer) = layer
unFree _ =
  error "Language.C.DSL.unFree: expected a statement, but the program has none"
-- | Things that can be read back out of the statement that declares them.
--
-- Each instance understands exactly one statement constructor.  Handing it
-- any other statement is a programming error and aborts with a message that
-- names the constructor that was expected.
class Declarable a where
  extractDecl :: Stmt b -> a

instance Declarable Var where
  extractDecl (VarDecl v _) = v
  extractDecl _ =
    error "Language.C.DSL.extractDecl: expected a variable declaration (VarDecl)"

instance Declarable Type where
  extractDecl (TypeDecl t _) = t
  extractDecl _ =
    error "Language.C.DSL.extractDecl: expected a type declaration (TypeDecl)"

instance Declarable Fun where
  extractDecl (FunDecl f _) = f
  extractDecl _ =
    error "Language.C.DSL.extractDecl: expected a function declaration (FunDecl)"
-- | Declare a function.
--
-- Arguments: the return-type declarator, the function name, the parameter
-- declarations and the body.  The value bound from the resulting statement
-- is a call builder: applied to argument expressions it emits a call to this
-- function as an expression statement.
--
-- The declarator applied to the name, and every parameter, are read back
-- with 'extractDecl' as variables, so they must be variable declarations.
fun :: (Iden -> Free Stmt a) -> Iden -> [Free Stmt b] -> Free Stmt () -> Free Stmt ([Expr] -> Free Stmt ())
fun retDecl name paramDecls body = liftF (FunDecl f call)
  where
    signature = unVar (extractDecl (unFree (retDecl name)))
    params = [extractDecl (unFree d) | d <- paramDecls]
    f = Fun signature params body
    call args = stmt (FunCall f args)
-- | Render the items as a parenthesised, comma-separated list.
fmtArgList :: Format a => [a] -> String
fmtArgList items = paren (commaSeparate [fmt item | item <- items])
-- | Operator symbols used by 'BinExpr' and 'UnExpr'.
data Op
  -- arithmetic
  = AddOp | MulOp | SubOp | DivOp | RemOp
  -- ordering comparisons
  | LtOp | GtOp | LeOp | GeOp
  | XorOp
  | ComplementOp
  | NotOp
  | AndOp
  | PreIncrOp
  -- post-increment lives in its own Expr constructor (PostIncrExpr) so that
  -- it gets its own Format case
  | EqOp | NeqOp
  | NegateOp
  | AddrOfOp
  | DerefOp
  deriving Show
-- | The source spelling of each operator.
--
-- Several operators share a spelling ('AndOp' / 'AddrOfOp' are both @&@,
-- 'SubOp' / 'NegateOp' are both @-@, 'MulOp' / 'DerefOp' are both @*@);
-- whether it is printed infix or prefix is decided by the enclosing 'Expr'.
instance Format Op where
  fmt AddOp = "+"
  fmt MulOp = "*"
  fmt SubOp = "-"
  fmt DivOp = "/"
  fmt RemOp = "%"
  fmt LtOp = "<"
  fmt GtOp = ">"
  fmt LeOp = "<="
  fmt GeOp = ">="
  -- the original clause carried a stray extra argument, which gave the
  -- clauses differing arities and stopped the instance from compiling
  fmt XorOp = "^"
  fmt ComplementOp = "~"
  fmt NotOp = "!"
  fmt AndOp = "&"
  fmt PreIncrOp = "++"
  fmt EqOp = "=="
  fmt NeqOp = "!="
  fmt AddrOfOp = "&"
  fmt NegateOp = "-"
  fmt DerefOp = "*"
-- | Expression syntax tree.
data Expr
  = VarExpr Var                -- reference to a variable
  | FunCall Fun [Expr]         -- call with argument expressions
  | AssignExpr Expr Expr       -- target, then value
  | MemberAccess Expr Iden
  | EmptyExpr
  | CommentedExpr String Expr  -- comment text attached to an expression
  | BinExpr Op Expr Expr
  | UnExpr Op Expr
  | PostIncrExpr Expr
  | LitExpr String             -- literal, already rendered to text
  | RawExpr String             -- arbitrary source text
  deriving Show
-- | Anything that can be used where an expression is expected.
class ExprLike e where
  expr :: e -> Expr

instance ExprLike Var where
  expr = VarExpr

instance ExprLike Expr where
  expr = id

-- | A statement program can stand in for an expression when its first
-- statement is a variable or array declaration (it then denotes that
-- variable) or an expression statement (it then denotes that expression).
-- Any other program aborts with a descriptive message instead of an
-- anonymous pattern-match failure.
instance ExprLike (Free Stmt a) where
  expr (Free (VarDecl v _)) = VarExpr v
  expr (Free (ArrayDecl v _ _)) = VarExpr v
  expr (Free (ExprStmt e _)) = e
  expr _ =
    error "Language.C.DSL.expr: statement does not denote an expression"

{- this is gnarly. i want:
 - all Show a : expr a = LitExpr (show a)
 - except String : expr str = RawExpr
 - and to annotate string literals explicitly: cstrlit = LitExpr . show
 -}
-- | A bare Haskell string is spliced in as raw source text.
instance ExprLike String where
  expr = RawExpr

instance ExprLike Int where
  expr = LitExpr . show

-- instance ExprLike Double where
--   expr = LitExpr . show

instance ExprLike Char where
  expr = LitExpr . show

-- | An explicit string literal: the text is quoted and escaped by 'show'.
cstrlit :: String -> Expr
cstrlit = LitExpr . show
-- | Emit an expression as a statement.
stmt :: Expr -> Free Stmt ()
stmt e = liftF (ExprStmt e ())

-- | Emit a binary expression statement built from two expression-like operands.
binExpr :: (ExprLike a, ExprLike b) => Op -> a -> b -> Free Stmt ()
binExpr op lhs rhs = stmt (BinExpr op (expr lhs) (expr rhs))
-- Only assignment gets an explicit (low) precedence; the other operators
-- carry no fixity declaration and therefore use the language default.
infixl 2 =:

(=:), (==:), (!=:), (+:), (<:), (>:), (<=:), (>=:), (-:), (*:), (/:), (^:), (&:) :: (ExprLike a, ExprLike b) => a -> b -> Free Stmt ()
target =: value = stmt (AssignExpr (expr target) (expr value))
a +: b = binExpr AddOp a b
a -: b = binExpr SubOp a b
a *: b = binExpr MulOp a b
a /: b = binExpr DivOp a b
a ^: b = binExpr XorOp a b
a &: b = binExpr AndOp a b
a <: b = binExpr LtOp a b
a >: b = binExpr GtOp a b
a <=: b = binExpr LeOp a b
a >=: b = binExpr GeOp a b
a ==: b = binExpr EqOp a b
a !=: b = binExpr NeqOp a b
-- | Emit a prefix unary expression statement.
unExpr :: ExprLike a => Op -> a -> Free Stmt ()
unExpr op operand = stmt (UnExpr op (expr operand))

deref, addressof, preincr, postincr, bitCompl, boolCompl, (!:), (~:) :: ExprLike a => a -> Free Stmt ()
deref e = unExpr DerefOp e
addressof e = unExpr AddrOfOp e
preincr e = unExpr PreIncrOp e
-- post-increment has its own Expr constructor rather than an Op
postincr e = stmt (PostIncrExpr (expr e))
boolCompl e = unExpr NotOp e
bitCompl e = unExpr ComplementOp e
-- operator aliases for the two complement functions
(!:) = boolCompl
(~:) = bitCompl

-- | Attach comment text to an expression.
insertComment :: ExprLike a => String -> a -> Expr
insertComment text e = CommentedExpr text (expr e)
-- | An include target; the two constructors select the two include forms.
data Include
  = Include String
  | LocalInclude String
  deriving Show
-- | Statement functor for the free monad.  In every constructor the final
-- field, @next@, is the remainder of the program.
data Stmt next
  -- declarations
  = TypeDecl Type next
  | ArrayDecl Var Int next
  -- compound statements carrying a nested program
  | BlockStmt (Free Stmt ()) next
  | IfStmt Expr (Free Stmt ()) next
  | WhileStmt Expr (Free Stmt ()) next
  | VarDecl Var next
  | FunDecl Fun next
  | ExprStmt Expr next
  | BreakStmt next
  | SwitchStmt Expr (Free CaseStmt ()) next
  -- text emitted as-is
  | RawStmt String next
  | IncludeStmt Include next
  deriving (Show, Functor)
-- | Functor for the arms of a switch; the final field is the rest of the arms.
data CaseStmt a
  = CaseStmt Expr (Free Stmt ()) a
  -- a dedicated CaseBreakStmt constructor was considered and left disabled:
  -- | CaseBreakStmt Expr (Free Stmt ()) a
  | DefaultStmt (Free Stmt ()) a
  deriving (Show, Functor)
-- type declarations

-- | Emit a type declaration.  The value bound from it is a declarator for
-- variables of that type.
typeDecl :: Type -> Free Stmt (Iden -> Free Stmt Expr)
typeDecl t = liftF (TypeDecl t (varDecl t))

-- | Declare a struct from its name and its (type, name) members.
struct :: Iden -> [(Type,Iden)] -> Free Stmt (Iden -> Free Stmt Expr)
struct name members = typeDecl (CStructUnion CStructTag name members)

-- | Declare a union from its name and its (type, name) members.
cunion :: Iden -> [(Type,Iden)] -> Free Stmt (Iden -> Free Stmt Expr)
cunion name members = typeDecl (CStructUnion CUnionTag name members)

-- | Declare a typedef: the aliased type is taken from the given declarator
-- and the second argument is the new name.
typedef :: (Iden -> Free Stmt a) -> Iden -> Free Stmt (Iden -> Free Stmt Expr)
typedef declarator alias = typeDecl (TypeDef (unsafeExtractType declarator) alias)

-- typeDef :: (Iden -> Free Stmt a) -> Iden -> Free Stmt Expr
-- typeDef f = varDecl (unsafeExtractType f)
-- block statements

-- | @if@ with the given condition and body.
cif :: (ExprLike a) => a -> Free Stmt () -> Free Stmt ()
cif cond body = liftF (IfStmt (expr cond) body ())

-- | @while@ with the given condition and body.
cwhile :: (ExprLike a) => a -> Free Stmt () -> Free Stmt ()
cwhile cond body = liftF (WhileStmt (expr cond) body ())

-- | A braced block.
block :: Free Stmt () -> Free Stmt ()
block body = liftF (BlockStmt body ())

-- | @switch@ on the given expression with the given arms.
cswitch :: (ExprLike a) => a -> Free CaseStmt () -> Free Stmt ()
cswitch scrutinee arms = liftF (SwitchStmt (expr scrutinee) arms ())

-- | A @case@ arm; nothing is appended to the body.
ccase :: (ExprLike a) => a -> Free Stmt () -> Free CaseStmt ()
ccase label body = liftF (CaseStmt (expr label) body ())

-- | A @case@ arm whose body is followed by a break statement.
ccaseBreak :: (ExprLike a) => a -> Free Stmt () -> Free CaseStmt ()
ccaseBreak label body = liftF (CaseStmt (expr label) bodyThenBreak ())
  where
    bodyThenBreak = body >> liftF (BreakStmt ())

-- | The @default@ arm.
cdefaultCase :: Free Stmt () -> Free CaseStmt ()
cdefaultCase body = liftF (DefaultStmt body ())
-- misc statements

-- | Include statement using the 'Include' form.
include :: String -> Free Stmt ()
include path = liftF (IncludeStmt (Include path) ())

-- | Include statement using the 'LocalInclude' form.
includeLocal :: String -> Free Stmt ()
includeLocal path = liftF (IncludeStmt (LocalInclude path) ())

-- | Emit the given text as a statement, unchanged.
rawStmt :: String -> Free Stmt ()
rawStmt text = liftF (RawStmt text ())

-- | Emit the text wrapped in block-comment delimiters.  The text is not
-- escaped.
comment :: String -> Free Stmt ()
comment text = rawStmt (surround "/*" "*/" text)
-- | Declare a variable of the given type and name; the bound result is an
-- expression referring to that variable.
varDecl :: Type -> Iden -> Free Stmt Expr
varDecl t name = liftF (VarDecl v (VarExpr v))
  where
    v = Var (TypedIden t name)
-- | Recover the type a declarator would declare, without having a name.
--
-- The declarator is applied to 'undefined' and only the type of the
-- variable declared by its first statement is read back, so the bogus
-- argument is never looked at here.  This relies on laziness and on the
-- declarator's first statement being a variable declaration; admittedly
-- poor style, hence the "unsafe" in the name.
unsafeExtractType :: (a -> Free Stmt b) -> Type
unsafeExtractType declarator = extractType (unVar (extractDecl (unFree probe)))
  where
    probe = declarator undefined
-- | Build a declarator whose type is the given declarator's type passed
-- through a type transformer.
chgType :: (Type -> Type) -> (Iden -> Free Stmt a) -> Iden -> Free Stmt Expr
chgType wrap declarator name = varDecl (wrap (unsafeExtractType declarator)) name

-- Declarator modifiers: pointer-to, reference-to and const-qualified.
ptr,ref,cconst :: (Iden -> Free Stmt a) -> Iden -> Free Stmt Expr
ptr declarator name = chgType CPtr declarator name
ref declarator name = chgType CRef declarator name
cconst declarator name = chgType CConst declarator name
-- | Declare an array: element declarator, array name, then the size.  The
-- bound result is an expression referring to the array variable.
carray :: (Iden -> Free Stmt Expr) -> Iden -> Int -> Free Stmt Expr
carray declarator name size = liftF (ArrayDecl v size (expr v))
  where
    v = Var (TypedIden (unsafeExtractType declarator) name)
-- Declarators for the built-in arithmetic types: each declares a variable
-- of the corresponding 'Type' under the given name.
char, schar, uchar, short, ushort, int, uint :: Iden -> Free Stmt Expr
long, ulong, llong, ullong, float, double, ldouble :: Iden -> Free Stmt Expr
char name = varDecl CChar name
schar name = varDecl CSChar name
uchar name = varDecl CUChar name
short name = varDecl CShort name
ushort name = varDecl CUShort name
int name = varDecl CInt name
uint name = varDecl CUInt name
long name = varDecl CLong name
ulong name = varDecl CULong name
llong name = varDecl CLLong name
ullong name = varDecl CULLong name
float name = varDecl CFloat name
double name = varDecl CDouble name
ldouble name = varDecl CLDouble name
-- | Format a sub-expression, wrapping it in parentheses when it is an
-- assignment, a binary expression or a unary expression, so that nesting
-- stays unambiguous.  Every other expression is formatted bare.
parenFmt :: Expr -> String
parenFmt e
  | isCompound e = paren (fmt e)
  | otherwise = fmt e
  where
    isCompound (AssignExpr _ _) = True
    isCompound (BinExpr _ _ _) = True
    isCompound (UnExpr _ _) = True
    isCompound _ = False
-- | Render an expression.  Operands of binary, unary and post-increment
-- expressions go through 'parenFmt', which parenthesises compound operands.
--
-- Every constructor of 'Expr' is covered.  'MemberAccess' and 'EmptyExpr'
-- previously had no clause, so formatting either one crashed with a
-- pattern-match failure.
instance Format Expr where
  fmt (VarExpr (Var (TypedIden _ i))) = i
  fmt (FunCall (Fun (TypedIden _ i) _ _) es) = concat [i, paren $ commaSeparate (map fmt es)]
  fmt (AssignExpr e e') = unwords [fmt e, "=", fmt e']
  -- rendered with the direct member operator; access through a pointer
  -- ("->") is not distinguished by the DSL
  fmt (MemberAccess e i) = concat [parenFmt e, ".", i]
  fmt EmptyExpr = ""
  -- the comment text is placed directly after the expression
  fmt (CommentedExpr s e) = concat [fmt e, "/*", s, "*/"]
  fmt (BinExpr op e e') = unwords [parenFmt e, fmt op, parenFmt e']
  fmt (UnExpr op e) = concat [fmt op, parenFmt e]
  fmt (PostIncrExpr e) = concat [parenFmt e, "++"]
  fmt (LitExpr s) = s
  fmt (RawExpr s) = s
-- | Strings format as themselves.
instance Format String where
  fmt s = s

-- | A variable formats as its underlying typed identifier.
instance Format Var where
  fmt = fmt . unVar

-- | Type, a space, then the identifier.
instance Format TypedIden where
  fmt (TypedIden ty name) = fmt ty ++ " " ++ name
-- | Format a value as the indented contents of a brace-delimited block.
fmtBlock :: (Format a) => a -> String
fmtBlock body = surround "{\n" "}" (indent (fmt body))
-- | A function definition: typed name, parenthesised parameter list, then
-- the body as a block.
instance Format Fun where
  fmt (Fun sig params body) = unwords [header, fmtBlock body]  -- TODO make this BlockStmt
    where
      header = fmt sig ++ paren (commaSeparate (map fmt params))
-- | @appendLine next s@ is @s@, a newline, then the rendering of @next@.
appendLine :: (Format a) => a -> (String -> String)
appendLine next current = current ++ "\n" ++ fmt next
-- | Angle brackets for 'Include', double quotes for 'LocalInclude'.
instance Format Include where
  fmt inc = case inc of
    Include path -> "<" ++ path ++ ">"
    LocalInclude path -> "\"" ++ path ++ "\""
-- | Render a statement program: each statement is rendered and followed, on
-- a new line, by the rendering of the rest of the program.
--
-- 'TypeDecl' used to be handled only for 'TypeDef', so a program containing
-- a struct or union declaration crashed the formatter.  Struct and union
-- declarations are now written out with their members; a 'TypeDecl' of any
-- other type has nothing to define and emits no text of its own.
instance (Format a) => Format (Free Stmt a) where
  fmt (Pure _) = ""
  fmt (Free (VarDecl x next)) = appendLine next $ semicolon $ fmt x
  fmt (Free (FunDecl x next)) = appendLine next $ fmt x
  fmt (Free (ExprStmt x next)) = appendLine next $ semicolon $ fmt x
  fmt (Free (BlockStmt x next)) = appendLine next $ fmtBlock x
  fmt (Free (BreakStmt next)) = appendLine next $ "break;"
  fmt (Free (RawStmt s next)) = appendLine next $ s
  fmt (Free (IncludeStmt inc next)) = appendLine next $
    unwords ["#include", fmt inc]
  fmt (Free (TypeDecl (TypeDef t i) next)) = appendLine next $
    semicolon $ unwords ["typedef", fmt t, i]
  fmt (Free (TypeDecl t@(CStructUnion _ _ membs) next)) = appendLine next $
    semicolon $ unwords [fmt t, fmtBlock members]
    where
      -- one "type name;" line per member
      members = unlines [semicolon (unwords [fmt ty, i]) | (ty, i) <- membs]
  fmt (Free (TypeDecl _ next)) = fmt next
  fmt (Free (ArrayDecl v sz next)) = appendLine next $
    semicolon $ concat [fmt v, "[", show sz, "]"]
  fmt (Free (IfStmt cond body next)) = appendLine next $
    unwords ["if", paren $ fmt cond, fmtBlock body]
  fmt (Free (WhileStmt cond body next)) = appendLine next $
    unwords ["while", paren $ fmt cond, fmtBlock body]
  fmt (Free (SwitchStmt e s next)) = appendLine next $
    unwords ["switch", paren $ fmt e, fmtBlock s]
-- | Render the arms of a switch, one after another.
instance (Format a) => Format (Free CaseStmt a) where
  fmt arms = case arms of
    Pure _ -> ""
    Free (CaseStmt label body next) ->
      appendLine next (unwords ["case", paren (fmt label), ":", fmtBlock body])
    Free (DefaultStmt body next) ->
      appendLine next (unwords ["default", ":", fmtBlock body])
-- | Unit results render as nothing.
instance Format () where
  fmt = const ""

-- | Lists of unit results render as nothing.
instance Format [()] where
  fmt = const ""
-- | Tiny sample program: declares two ints and emits their sum as an
-- expression statement.
foo :: Free Stmt ()
foo = do
  lhs <- int "x1"
  rhs <- int "y1"
  lhs +: rhs
-- | Rotate the values of the given variables through a fresh @tmp@
-- variable: @tmp@ receives the first variable, every variable receives its
-- successor, and the last variable receives @tmp@.
swap :: [Expr] -> Free Stmt [()]
swap vars = do
  tmp <- int "tmp"
  let ring = tmp : vars
  -- pair every element with its successor, wrapping around to tmp
  mapM (uncurry (=:)) (zip ring (drop 1 (cycle ring)))
-- | Write the generated source to @intermediatePath@ behind a
-- "do not edit" comment, echo the source (without that comment) to stdout,
-- then run @g++@ on the written file to produce @outPath@.  Progress
-- markers go to stderr; the text captured from @g++@ by 'readProcess' is
-- discarded.
--
-- (A colourised stdout/stderr variant based on System.Console.ANSI was
-- sketched here and is currently disabled.)
compile :: (Format a) => FilePath -> FilePath -> Free Stmt a -> IO ()
compile intermediatePath outPath code = do
  hPutStrLn stderr "codegen.."
  writeFile intermediatePath (fmt (comment warningLine >> code))
  putStr (fmt code)
  hPutStrLn stderr "g++.."
  _ <- readProcess "g++" compilerArgs ""
  hPutStrLn stderr "done"
  where
    warningLine = "This code was auto-generated! DO NOT EDIT DIRECTLY!"
    compilerArgs = [intermediatePath, "-O2", "-Wall", "-g", "-o", outPath]
-- String utils for language related formatting tasks
module Language.Format
( surround
, paren
, singleQuote
, doubleQuote
, indent
, semicolon
, newline
, commaSeparate
, underscore
, Format(..)
, prettyPrint
)
where
import Data.List
import Data.Char (isSpace)
-- | Types that can be rendered to source text.
class Format a where
  fmt :: a -> String

-- | Render a value and write it to stdout with blank lines removed.
prettyPrint :: Format a => a -> IO ()
prettyPrint x = putStr (rmBlankLines (fmt x))

-- | @surround left right s@ puts @s@ between @left@ and @right@.
surround :: String -> String -> String -> String
surround left right s = left ++ s ++ right

-- | Wrap in round brackets.
paren :: String -> String
paren s = surround "(" ")" s

-- | Wrap in single quotes (no escaping).
singleQuote :: String -> String
singleQuote s = surround "'" "'" s

-- | Wrap in double quotes (no escaping).
doubleQuote :: String -> String
doubleQuote s = surround "\"" "\"" s

-- | Prefix every line with one space; the result always ends in a newline
-- unless the input has no lines at all.
indent :: String -> String
indent s = unlines [' ' : l | l <- lines s]

-- | Drop every line that consists only of whitespace.
rmBlankLines :: String -> String
rmBlankLines s = unlines [l | l <- lines s, any (not . isSpace) l]

-- | Append a semicolon.
semicolon :: String -> String
semicolon s = s ++ ";"

-- | Append a newline.
newline :: String -> String
newline s = s ++ "\n"

-- | Join with a comma followed by a space.
commaSeparate :: [String] -> String
commaSeparate items = intercalate ", " items

-- | Replace each run of whitespace between words with a single underscore.
underscore :: String -> String
underscore s = intercalate "_" (words s)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment