Skip to content

Instantly share code, notes, and snippets.

@osa1
Created September 27, 2012 17:36
Show Gist options
  • Save osa1/3795313 to your computer and use it in GitHub Desktop.
Save osa1/3795313 to your computer and use it in GitHub Desktop.
kaleidoscope 3
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE NamedFieldPuns #-}
module Main where
import LLVM.Wrapper.Core
import LLVM.Wrapper.Analysis
import qualified LLVM.FFI.Core as LL
import qualified Data.Map as M
import Control.Monad.State
import Control.Monad.Error
-- | State threaded through code generation.
data CompilerState = CompilerState
  { env     :: M.Map String Value
    -- ^ Variables currently in scope, mapped to their LLVM values.
  , protos  :: M.Map String FunProto
    -- ^ Function prototypes keyed by name (not read by the code in this file).
  , builder :: Builder
    -- ^ LLVM instruction builder used to emit instructions.
  , lmodule :: Module
    -- ^ LLVM module that functions are looked up in and added to.
  }
-- | The code-generation monad: 'CompilerState' is threaded with 'StateT'
-- over 'ErrorT' 'String', so failures short-circuit with a 'String'
-- message, and 'IO' underneath is available for the LLVM calls.
--
-- NOTE(review): 'ErrorT' / "Control.Monad.Error" are deprecated in newer
-- mtl releases in favour of 'ExceptT'; consider migrating.
type Compile = StateT CompilerState (ErrorT String IO)
-- | Abstract syntax of Kaleidoscope expressions.
data Expr
  = NumberExpr String           -- ^ numeric literal, kept as its source text
  | VariableExpr String         -- ^ reference to a variable by name
  | BinaryExpr Char Expr Expr   -- ^ binary operator applied to two operands
  | CallExpr String [Expr]      -- ^ call of a named function with arguments
  | PrototypeExpr FunProto      -- ^ function declaration without a body
  | FunctionExpr FunProto Expr  -- ^ function definition: prototype plus body
  deriving (Show)
-- | A function signature: the function's name followed by the names of its
-- parameters, in order.
data FunProto = FunProto String [String]
  deriving (Show)
-- | Generate LLVM IR for an expression and return the resulting 'Value'.
--
-- Errors (unknown variable or function, invalid operator, wrong number of
-- call arguments, failed function verification) are reported with
-- 'throwError'.
codegen :: Expr -> Compile Value
codegen (NumberExpr num) = return $ constRealOfString doubleType num
codegen (VariableExpr var) = do
  CompilerState{env} <- get
  case M.lookup var env of
    Nothing -> throwError $ "unknown variable name: " ++ var
    Just v  -> return v
codegen (BinaryExpr op left right) = do
  leftval  <- codegen left
  rightval <- codegen right
  CompilerState{builder} <- get
  case op of
    '+' -> liftIO $ buildAdd builder leftval rightval "addtmp"
    '-' -> liftIO $ buildSub builder leftval rightval "subtmp"
    '*' -> liftIO $ buildMul builder leftval rightval "multmp"
    '<' -> do
      -- 'FPULT' (unordered or less-than) produces a boolean; convert the
      -- 0/1 result to a double 0.0 / 1.0, since all values here are doubles.
      cmp <- liftIO $ buildFCmp builder LL.FPULT leftval rightval "cmptmp"
      liftIO $ buildUIToFP builder cmp doubleType "booltmp"
    inv -> throwError $ "invalid binary operator: " ++ [inv]
codegen (CallExpr callee args) = do
  CompilerState{builder, lmodule} <- get
  fun <- liftIO $ getNamedFunction lmodule callee
  calleeFun <- case fun of
    Nothing -> throwError $ "Unknown function referenced: " ++ callee
    Just fn -> return fn
  params <- liftIO $ getParams calleeFun
  when (length params /= length args) $
    throwError "Incorrect # arguments passed."
  compiledArgs <- mapM codegen args
  ret <- liftIO $ buildCall builder calleeFun compiledArgs "calltmp"
  liftIO $ dumpValue ret -- debug output
  return ret
codegen (PrototypeExpr proto) = codegenProto proto
codegen (FunctionExpr proto body) = do
  -- Each function body sees only its own parameters, which codegenProto
  -- inserts into the freshly emptied environment.
  modify $ \st -> st{env = M.empty}
  fun <- codegenProto proto
  CompilerState{builder} <- get
  bb <- liftIO $ appendBasicBlock fun "entry"
  liftIO $ LL.positionAtEnd builder bb
  -- If the body fails to compile, remove the half-built function so the
  -- module is not left with a broken definition, then re-raise the error.
  retval <- codegen body `catchError` \err -> do
    liftIO $ LL.deleteFunction fun
    throwError err
  _ <- liftIO $ buildRet builder retval
  verify <- liftIO $ verifyFunction fun
  liftIO $ putStrLn "verify function called"
  -- NOTE(review): the C API's LLVMVerifyFunction returns true when the
  -- function is *broken*; confirm that this wrapper's 'verifyFunction'
  -- returns True for a valid function, otherwise this branch is inverted.
  if verify
    then do
      liftIO $ dumpValue fun
      return fun
    else do
      liftIO $ LL.deleteFunction fun
      throwError "error while creating function"
-- | Declare the function described by a prototype in the current module,
-- or reuse an existing declaration with the same name, and bring its
-- parameters into scope.
--
-- Every parameter and the return value have type 'doubleType'. Fails with
-- 'throwError' if a function with that name already has a body, or if an
-- existing declaration has a different number of parameters. On success,
-- the parameter values are named after the prototype's argument names and
-- inserted into 'env'.
codegenProto :: FunProto -> Compile Value
codegenProto (FunProto name args) = do
  liftIO $ putStrLn "entering codegenProto"
  CompilerState{lmodule} <- get
  existing <- liftIO $ getNamedFunction lmodule name
  fun <- case existing of
    Nothing -> do
      liftIO $ putStrLn "adding function"
      liftIO $ addFunction lmodule name ft
    Just f -> do
      liftIO $ putStrLn "returning existing function"
      return f
  -- A function that already has basic blocks already has a body; reject a
  -- second definition.
  blockCount <- liftIO $ LL.countBasicBlocks fun
  when ((fromIntegral blockCount :: Int) /= 0) $
    throwError $ "redefinition of function: " ++ name
  argvals <- liftIO $ getParams fun
  -- A reused declaration must agree on arity; otherwise the zips below
  -- would silently drop parameters.
  when (length argvals /= length args) $
    throwError $ "redefinition of function with different # args: " ++ name
  liftIO $ zipWithM_ setValueName argvals args
  -- Bring the parameters into scope. Folding from the right means the
  -- first occurrence of a duplicated parameter name wins.
  modify $ \st ->
    st{env = foldr (uncurry M.insert) (env st) (zip args argvals)}
  liftIO $ dumpValue fun
  return fun
  where
    -- The final Bool presumably maps to the IsVarArg flag of the C API's
    -- LLVMFunctionType; False means the function is not variadic.
    ft = functionType doubleType (replicate (length args) doubleType) False
-- | Demo driver: define @foo(a, b) = a + b@ and compile a call
-- @foo(123, 10)@. It prints the IR of the call (or the error message), then
-- dumps the whole module.
main :: IO ()
main = do
  lmodule <- moduleCreateWithName "my jit compiler"
  builder <- LL.createBuilder
  let initial = CompilerState M.empty M.empty builder lmodule
      fooDef  = FunctionExpr (FunProto "foo" ["a", "b"])
                  (BinaryExpr '+' (VariableExpr "a") (VariableExpr "b"))
      fooCall = CallExpr "foo" [NumberExpr "123", NumberExpr "10"]
  result <- runErrorT (evalStateT (codegen fooDef >> codegen fooCall) initial)
  either putStrLn dumpValue result
  -- Dump after code generation, so the module actually contains 'foo'.
  LL.dumpModule lmodule
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment