Created
September 27, 2012 17:36
-
-
Save osa1/3795313 to your computer and use it in GitHub Desktop.
kaleidoscope 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
{-# LANGUAGE NamedFieldPuns #-} | |
module Main where | |
import LLVM.Wrapper.Core | |
import LLVM.Wrapper.Analysis | |
import qualified LLVM.FFI.Core as LL | |
import qualified Data.Map as M | |
import Control.Monad.State | |
import Control.Monad.Error | |
-- | State threaded through code generation.
data CompilerState = CompilerState
  { -- | Values bound to names; consulted when a variable reference is
    -- compiled and refilled with the parameters of each prototype.
    env :: M.Map String Value
  , -- | Prototypes keyed by name.
    -- NOTE(review): not read or written by any code visible here -- confirm
    -- whether it is still needed.
    protos :: M.Map String FunProto
  , -- | Instruction builder handed to every @build*@ call.
    builder :: Builder
  , -- | Module in which functions are looked up and added.
    lmodule :: Module
  }
-- | The code-generation monad: 'CompilerState' layered over 'String'-valued
-- errors, with 'IO' at the bottom for the LLVM calls.
type Compile =
  StateT CompilerState (ErrorT String IO)
-- | Abstract syntax accepted by 'codegen'.
data Expr
  = -- | Numeric literal, kept as source text and turned into a double
    -- constant at code-generation time.
    NumberExpr String
  | -- | Reference to a named variable.
    VariableExpr String
  | -- | Binary operator applied to a left and a right operand.
    BinaryExpr Char Expr Expr
  | -- | Call of a named function with its argument expressions.
    CallExpr String [Expr]
  | -- | Function declaration without a body.
    PrototypeExpr FunProto
  | -- | Function definition: a prototype plus the body expression.
    FunctionExpr FunProto Expr
  deriving (Show)
-- | Function prototype: the function's name followed by its parameter names.
data FunProto
  = FunProto String [String]
  deriving (Show)
-- | Emit LLVM IR for an expression and return the resulting value.
--
-- Failures are reported through 'throwError' with a message string:
-- unknown variable, invalid binary operator, unknown callee, wrong number
-- of call arguments, or a function that does not pass 'verifyFunction'.
codegen :: Expr -> Compile Value
codegen (NumberExpr num) = return $ constRealOfString doubleType num
codegen (VariableExpr var) = do
  CompilerState{env} <- get
  case M.lookup var env of
    Nothing -> throwError $ "unknown variable name: " ++ var
    Just v -> return v
codegen (BinaryExpr op left right) = do
  leftval <- codegen left
  rightval <- codegen right
  CompilerState{builder} <- get
  -- NOTE(review): every value produced by this module is a double, yet
  -- buildAdd/buildSub/buildMul are used here. LLVM has separate
  -- floating-point builders (FAdd/FSub/FMul); confirm whether the wrapper
  -- exposes them and switch if so.
  case op of
    '+' -> liftIO $ buildAdd builder leftval rightval "addtmp"
    '-' -> liftIO $ buildSub builder leftval rightval "subtmp"
    '*' -> liftIO $ buildMul builder leftval rightval "multmp"
    '<' -> do
      -- convert bool 0/1 to double 0.0 or 1.0
      cmp <- liftIO $ buildFCmp builder LL.FPULT leftval rightval "cmptmp"
      liftIO $ buildUIToFP builder cmp doubleType "booltmp"
    inv -> throwError $ "invalid binary operator: " ++ [inv]
codegen (CallExpr callee args) = do
  CompilerState{builder, lmodule} <- get
  fun <- liftIO $ getNamedFunction lmodule callee
  calleeFun <- case fun of
    Nothing -> throwError $ "Unknown function referenced: " ++ callee
    Just fn -> return fn
  -- Reject the call before compiling any argument if the arity is wrong.
  params <- liftIO $ getParams calleeFun
  when (length params /= length args) (throwError "Incorrect # arguments passed.")
  compiledArgs <- mapM codegen args
  ret <- liftIO $ buildCall builder calleeFun compiledArgs "calltmp"
  liftIO $ dumpValue ret
  return ret
codegen (PrototypeExpr proto) = codegenProto proto
codegen (FunctionExpr proto body) = do
  -- Start from an empty scope; codegenProto binds the parameters.
  modify $ \st -> st{env = M.empty}
  fun <- codegenProto proto
  CompilerState{builder} <- get
  bb <- liftIO $ appendBasicBlock fun "entry"
  liftIO $ LL.positionAtEnd builder bb
  -- If the body fails to compile, remove the half-built function before
  -- re-raising; otherwise its entry block would make every later attempt
  -- to define the same name fail as a redefinition.
  retval <- codegen body `catchError` \err -> do
    liftIO $ LL.deleteFunction fun
    throwError err
  _ <- liftIO $ buildRet builder retval
  verify <- liftIO $ verifyFunction fun
  liftIO $ putStrLn "verify function called"
  if verify
    then do
      liftIO $ dumpValue fun
      return fun
    else do
      liftIO $ LL.deleteFunction fun
      throwError "error while creating function"
-- | Declare the function described by a prototype, or reuse an existing
-- declaration of the same name, then name its parameters and bind them in
-- the environment.
--
-- Every parameter and the return value are typed as @doubleType@.
--
-- Throws when a function of that name already has a body, or when an
-- existing declaration has a different number of parameters.
codegenProto :: FunProto -> Compile Value
codegenProto (FunProto name args) = do
  liftIO $ putStrLn "entering codegenProto"
  st@CompilerState{lmodule} <- get
  existing <- liftIO $ getNamedFunction lmodule name
  fun <- case existing of
    Nothing -> do
      liftIO $ putStrLn "adding function"
      liftIO $ addFunction lmodule name ft
    Just f -> do
      liftIO $ putStrLn "returning existing function"
      return f
  -- if fun already has a body, reject this.
  blockCount <- liftIO $ LL.countBasicBlocks fun
  when ((fromIntegral blockCount :: Int) /= 0)
    (throwError $ "redefinition of function: " ++ name)
  -- A pre-existing declaration must agree on arity; zipping mismatched
  -- lists would silently drop parameters or names.
  argvals <- liftIO $ getParams fun
  when (length argvals /= length args)
    (throwError $ "redefinition of function with different # args: " ++ name)
  liftIO $ zipWithM_ setValueName argvals args
  -- Bind each parameter name to its LLVM value on top of the current env.
  put st{env = foldr (uncurry M.insert) (env st) (zip args argvals)}
  liftIO $ dumpValue fun
  return fun
  where
    doubles = replicate (length args) doubleType
    -- NOTE(review): the trailing False is presumably the "is variadic" flag
    -- of LLVM's function-type constructor -- confirm against the wrapper.
    ft = functionType doubleType doubles False
-- | Demo driver: define @foo a b = a + b@, compile a call to it, and dump
-- the resulting value (or print the error message).
main :: IO ()
main = do
  llvmModule <- moduleCreateWithName "my jit compiler"
  irBuilder <- LL.createBuilder
  LL.dumpModule llvmModule
  let fooDef =
        FunctionExpr
          (FunProto "foo" ["a", "b"])
          (BinaryExpr '+' (VariableExpr "a") (VariableExpr "b"))
      fooCall = CallExpr "foo" [NumberExpr "123", NumberExpr "10"]
      -- Record syntax keeps the two same-typed map fields from being swapped.
      initial =
        CompilerState
          { env = M.empty
          , protos = M.empty
          , builder = irBuilder
          , lmodule = llvmModule
          }
  result <- runErrorT (runStateT (codegen fooDef >> codegen fooCall) initial)
  case result of
    Left err -> putStrLn err
    Right (val, _) -> dumpValue val
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment