Skip to content

Instantly share code, notes, and snippets.

@y-yu
Created May 11, 2015 14:09
Show Gist options
  • Save y-yu/8fba01232bfcc05656c5 to your computer and use it in GitHub Desktop.
Save y-yu/8fba01232bfcc05656c5 to your computer and use it in GitHub Desktop.
Compile Regular Expressions to LLVM
/** Abstract syntax of the regular expressions accepted by the compiler.
  *
  * Extending `Product with Serializable` keeps the inferred type of mixed collections
  * (e.g. `List(Empty, Let('a'))`) at `Regex` instead of a widened intersection type.
  */
sealed trait Regex extends Product with Serializable

/** Matches the empty string (consumes no input). */
case object Empty extends Regex

/** Matches exactly the single character `c`. */
final case class Let(c: Char) extends Regex

/** Concatenation: `a` followed by `b`. */
final case class Con(a: Regex, b: Regex) extends Regex

/** Alternation: `a` or `b` (`a` is tried first). */
final case class Alt(a: Regex, b: Regex) extends Regex

/** Kleene star: zero or more repetitions of `a`. */
final case class Star(a: Regex) extends Regex
/** Operands that can appear in the emitted LLVM IR. */
sealed trait Value extends Product with Serializable

/** Unnamed local value or basic block, printed as `%n` (or `n` as a label number). */
final case class RInt(n: Int) extends Value

/** Named local value or label, printed as `%s`. */
final case class RStr(s: String) extends Value

/** Integer literal, printed as the bare number. */
final case class VInt(n: Int) extends Value

/** `blockaddress(f, l)`: the address of basic block `l` inside function `f`. */
final case class BA(f: String, l: Value) extends Value
/** The LLVM types used by the generator. */
sealed trait Type extends Product with Serializable

/** `i1` (boolean). */
case object I1 extends Type

/** `i8` (one byte of the input string). */
case object I8 extends Type

/** `i8*` (pointer into the input string). */
case object I8P extends Type

/** `i64` (the cursor value). */
case object I64 extends Type

/** `i64*` (pointer to the cursor slot). */
case object I64P extends Type
/** Comparison predicates for `icmp`. */
sealed trait Cond extends Product with Serializable

/** Integer equality (`icmp eq`). */
case object Eq extends Cond
/** The subset of LLVM instructions emitted by the regex compiler. */
sealed trait Inst extends Product with Serializable

/** Marks the start of unnamed basic block number `n`. It is printed as a comment;
  * LLVM assigns the number implicitly from the block's position. */
final case class Label(n: Value) extends Inst

/** `l = r`: binds the result of instruction `r` to value `l`. */
final case class Assign(l: Value, r: Inst) extends Inst

/** `add nsw t v, n`. */
final case class Add(t: Type, v: Value, n: Int) extends Inst

/** `icmp c t a, b`. */
final case class Cmp(c: Cond, t: Type, a: Value, b: Value) extends Inst

/** Unconditional branch to `d`. */
final case class Br1(d: Value) extends Inst

/** Conditional branch: to `t` when `c` is true, otherwise to `e`. */
final case class Br2(c: Value, t: Value, e: Value) extends Inst

/** Call of function `f` returning `rt`. The `(type, value)` pairs in `at` are printed
  * after an implicit leading `i8* %str` argument. */
final case class Call(f: String, rt: Type, at: List[(Type, Value)]) extends Inst

/** `load t p`, where `t` is the *pointer* type (pre-3.7 LLVM syntax). */
final case class Load(t: Type, p: Value) extends Inst

/** `store vt v, pt pv`. */
final case class Store(vt: Type, v: Value, pt: Type, pv: Value) extends Inst

/** `getelementptr inbounds t v, i64 i`. */
final case class GetElementPtr(t: Type, v: Value, i: Value) extends Inst
/** Compiles a [[Regex]] into textual LLVM IR (pre-3.7 syntax, e.g. `load i8* %p`).
  *
  * The generated `@test(i8* %str, i8* %l, i64 %sp_value)` keeps a cursor into `%str`
  * in the stack slot `%sp`. Alternation and Kleene star backtrack by calling `@test`
  * recursively. Each recursive call passes the `blockaddress` of the code to try and
  * the current cursor, and the callee resumes there through `indirectbr`.
  *
  * Unnamed values and basic blocks are numbered by hand. LLVM requires these numbers
  * to be consecutive in textual order, which is why the helpers thread a counter `n`.
  */
object RegexToLLVM {
  /** Stack slot holding the cursor into `%str`. */
  def nsp = RStr("sp")
  /** The input string parameter of `@test`. */
  def nstr = RStr("str")
  /** Block that returns `true`. */
  def nmatch = RStr("match")
  /** Block that returns `false`. */
  def nmiss = RStr("miss")
  /** The generated matcher function, also the target of recursive calls. */
  def fname = "@test"

  /** Binds `r` to unnamed value `%n`; returns the instruction and the next free number. */
  def assign(r: Inst, n: Int): (Inst, Int) = (Assign(RInt(n), r), n + 1)

  /** Starts unnamed block number `n`; returns the label and the next free number. */
  def mk_label(n: Int): (Inst, Int) = (Label(RInt(n)), n + 1)

  /** Translates `re` into the body of `@test`.
    *
    * Matching is anchored only at the start of the input. Reaching the end of the
    * pattern jumps to `match` regardless of remaining characters.
    *
    * NOTE: `Star(r)` where `r` can match the empty string (e.g. `Star(Empty)`) recurses
    * without consuming input and never terminates. Only `Star(Star(r))` is flattened.
    */
  def insts_of_regex(re: Regex): List[Inst] = {
    // Returns the code for `re` and the next free value/label number. A fragment that
    // ends in a terminator continues at the block numbered by that returned counter.
    def loop(re: Regex, n: Int): (List[Inst], Int) = re match {
      case Empty => (Nil, n)
      case Let(c) =>
        // Read the byte at the cursor, advance the cursor, go on only if it equals `c`.
        val (i1, n1) = mk_label(n)
        val (i2, n2) = assign(Load(I64P, nsp), n1)                      // value n1 := cursor
        val (i3, n3) = assign(GetElementPtr(I8P, nstr, RInt(n1)), n2)   // value n2 := &str[cursor]
        val (i4, n4) = assign(Add(I64, RInt(n1), 1), n3)                // value n3 := cursor + 1
        val i5 = Store(I64, RInt(n3), I64P, nsp)                        // cursor := cursor + 1
        val (i6, n5) = assign(Load(I8P, RInt(n2)), n4)                  // value n4 := str[old cursor]
        val (i7, n6) = assign(Cmp(Eq, I8, RInt(n4), VInt(c.toInt)), n5) // value n5 := (byte == c)
        val i8 = Br2(RInt(n5), RInt(n6), nmiss)
        (List(i1, i2, i3, i4, i5, i6, i7, i8), n6)
      case Con(a, b) =>
        val (i1, n1) = loop(a, n)
        val (i2, n2) = loop(b, n1)
        (i1 ++ i2, n2)
      case Alt(a, b) =>
        // Try `a` (followed by the rest of the pattern) in a recursive call from the
        // current cursor. On success the whole match succeeds; otherwise run `b` here.
        val (i1, n1) = mk_label(n)
        val (i2, n2) = assign(Load(I64P, nsp), n1)
        // The call result is value n2, so the code for `a` starts at block n2 + 1.
        val (i3, _) = assign(Call(fname, I1, List((I8P, BA(fname, RInt(n2 + 1))), (I64, RInt(n1)))), n2)
        val (i4, n4) = loop(a, n2 + 1)
        val (i5, n5) = mk_label(n4)
        val (i6, n6) = loop(b, n5)
        val i7 = Br1(RInt(n6)) // after `a`: skip over `b` to the continuation
        val i8 = Br2(RInt(n2), nmatch, RInt(n5))
        (List(i1, i2, i3, i8) ++ i4 ++ List(i5, i7) ++ i6, n6)
      case Star(Star(r)) => loop(Star(r), n) // (r*)* matches the same language as r*
      case Star(r) =>
        // Greedy: recursively try one more iteration of `r`, which loops back to this
        // block. If that fails, continue after the star from the current cursor.
        val (i1, n1) = mk_label(n)
        val (i2, n2) = assign(Load(I64P, nsp), n1)
        val (i3, _) = assign(Call(fname, I1, List((I8P, BA(fname, RInt(n2 + 1))), (I64, RInt(n1)))), n2)
        val (i4, n3) = loop(r, n2 + 1)
        val (i5, n4) = mk_label(n3)
        val i6 = Br1(RInt(n)) // back to the head of the star
        val i7 = Br2(RInt(n2), nmatch, RInt(n4))
        (List(i1, i2, i3, i7) ++ i4 ++ List(i5, i6), n4)
    }
    // Numbering starts at 1 because the unnamed entry block of @test takes number 0.
    val (i, n) = loop(re, 1)
    i ++ List(Label(RInt(n)), Br1(nmatch))
  }

  /** Prints a numbered label without the `%` sigil.
    * @throws IllegalArgumentException if `v` is not an [[RInt]]
    */
  def label_of_value(v: Value): String = v match {
    case RInt(n) => n.toString
    case other @ (_: RStr | _: VInt | _: BA) =>
      throw new IllegalArgumentException("not a numbered label: " + other)
  }

  /** Prints `v` as an LLVM operand. */
  def var_of_value(v: Value): String = v match {
    case RInt(n) => "%" + n
    case RStr(s) => "%" + s
    case VInt(n) => n.toString
    case BA(f, l) => "blockaddress(" + f + ", " + var_of_value(l) + ")"
  }

  /** Prints an `icmp` predicate. */
  def pp_cond(c: Cond): String = c match {
    case Eq => "eq"
  }

  /** Prints an LLVM type. */
  def pp_type(t: Type): String = t match {
    case I1 => "i1"
    case I8 => "i8"
    case I8P => "i8*"
    case I64 => "i64"
    case I64P => "i64*"
  }

  /** Byte alignment of an in-memory value of type `t`.
    * @throws IllegalArgumentException for `i1`, which this generator never stores
    */
  def align(t: Type): Int = t match {
    case I8 => 1
    case I8P => 8
    case I64 => 8
    case I64P => 8
    case I1 => throw new IllegalArgumentException("no alignment defined for i1")
  }

  // The type a pointer points to. A `load` must use the pointee's alignment, not the
  // pointer's: over-stating it (e.g. `align 8` for a byte load) is undefined in LLVM.
  private def pointee(t: Type): Type = t match {
    case I8P => I8
    case I64P => I64
    case other @ (I1 | I8 | I64) =>
      throw new IllegalArgumentException("load needs a pointer type, got " + other)
  }

  /** Prints one instruction, prefixed with `tab`. */
  def pp_inst(i: Inst, tab: String = ""): String =
    tab + (i match {
      case Label(n) =>
        "\n; <label>:" + label_of_value(n)
      case Assign(l, r) =>
        var_of_value(l) + " = " + pp_inst(r)
      case Add(t, v, n) =>
        "add nsw " + pp_type(t) + " " + var_of_value(v) + ", " + n
      case Cmp(c, t, a, b) =>
        "icmp " + pp_cond(c) + " " + pp_type(t) + " " + var_of_value(a) + ", " + var_of_value(b)
      case Br1(d) =>
        "br label " + var_of_value(d)
      case Br2(c, t, e) =>
        "br i1 " + var_of_value(c) + ", label " + var_of_value(t) + ", label " + var_of_value(e)
      case Call(f, rt, a) =>
        val s = a.foldLeft("i8* %str")((x, y) => x + ", " + pp_type(y._1) + " " + var_of_value(y._2))
        "call " + pp_type(rt) + " " + f + "(" + s + ")"
      case Load(t, p) =>
        "load " + pp_type(t) + " " + var_of_value(p) + ", align " + align(pointee(t))
      case Store(vt, v, pt, pv) =>
        "store " + pp_type(vt) + " " + var_of_value(v) + ", " + pp_type(pt) + " " +
          var_of_value(pv) + ", align " + align(vt)
      case GetElementPtr(t, v, i) =>
        "getelementptr inbounds " + pp_type(t) + " " + var_of_value(v) + ", i64 " + var_of_value(i)
    })

  /** Wraps the compiled instructions `i` into a complete LLVM module.
    *
    * The module contains `@test` and a `@main` that matches `argv[1]` against the
    * pattern and prints "match" or "unmatch". When no argument is given, `@main`
    * prints "unmatch" instead of dereferencing the NULL `argv[1]`.
    */
  def make(i: List[Inst]): String = {
    // Any numbered block can be a resume point of a recursive call, so every label is
    // a possible `indirectbr` destination.
    val targets = i.collect { case Label(n) => "label " + var_of_value(n) }.mkString(", ")
    val s =
      "@.match = private unnamed_addr constant [7 x i8] c\"match\\0A\\00\", align 1\n" +
      "@.unmatch = private unnamed_addr constant [9 x i8] c\"unmatch\\0A\\00\", align 1\n\n" +
      "define i1 @test(i8* %str, i8* %l, i64 %sp_value) {\n" +
      " %sp = alloca i64, align 8\n" +
      " store i64 %sp_value, i64* %sp, align 8\n" +
      " %isnull = icmp eq i8* %l, null\n" +
      " br i1 %isnull, label %1, label %jump\n\n" +
      "jump:\n" +
      " indirectbr i8* %l, [" + targets + "]\n"
    val llvmir = i.map(inst => pp_inst(inst, " ") + "\n").mkString
    val e =
      "\nmiss:\n" +
      " ret i1 0\n\n" +
      "match:\n" +
      " ret i1 1\n" +
      "}\n\n" +
      "define i32 @main(i32 %argc, i8** %argv) {\n" +
      " %hasarg = icmp sgt i32 %argc, 1\n" +
      " br i1 %hasarg, label %run, label %unmatch\n\n" +
      "run:\n" +
      " %arg1 = getelementptr inbounds i8** %argv, i64 1\n" +
      " %str = load i8** %arg1, align 8\n" +
      " %res = call i1 @test(i8* %str, i8* null, i64 0)\n" +
      " br i1 %res, label %match, label %unmatch\n\n" +
      "match:\n" +
      " call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([7 x i8]* @.match, i32 0, i32 0))\n" +
      " br label %ret\n\n" +
      "unmatch:\n" +
      " call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([9 x i8]* @.unmatch, i32 0, i32 0))\n" +
      " br label %ret\n\n" +
      "ret:\n" +
      " ret i32 0\n" +
      "}\n\n" +
      "declare i32 @printf(i8*, ...)"
    s + llvmir + e
  }

  /** Prints the LLVM module for one of the sample patterns below (currently `re8`). */
  def main(args: Array[String]): Unit = {
    // aa*bb*
    val re1 = Con( Con(Let('a'), Star(Let('a'))), Con(Let('b'), Star(Let('b'))) )
    // (a|b)c
    val re2 = Con( Alt(Let('a'), Let('b')), Let('c') )
    // (a*)*a
    val re3 = Con(Star(Star(Let('a'))), Let('a'))
    // aaaa
    val re4 = Con( Con(Let('a'), Let('a')), Con(Let('a'), Let('a')) )
    // (a*ba*)*a
    val re5 = Con(Star(Con(Con(Star(Let('a')), Let('b')), Star(Let('a')))), Let('a'))
    // sa*(ba*ba*)*a*e
    val re6 = Con(Con(Let('s'), Con(Con(Star(Let('a')), Star(Con(Con(Let('b'), Star(Let('a'))), Con(Let('b'), Star(Let('a')))))), Star(Let('a')))), Let('e'))
    // s(ab*)*e
    val re7 = Con(Con(Let('s'), Star(Con(Let('a'), Star(Let('b'))))), Let('e'))
    // ba*b
    val re8 = Con(Con(Let('b'), Star(Let('a'))), Let('b'))
    val i = insts_of_regex(re8)
    println(make(i))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment