Created June 16, 2020 04:11
Save BachiLi/30c7ada4fcba62fe220fd29597a31407 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast
import inspect
import textwrap

from adt import ADT
from adt import memo as ADTMemo
# Define the grammar
# ASDL-style description of the cohe IR, handed to the `adt` library. As the
# code below uses it, `cohe` then exposes each constructor (cohe.FloatConst,
# cohe.FloatAddTo, ...) and each sum type (cohe.stmt, cohe.block,
# cohe.kernel, ...) as attributes usable with isinstance.
#   float_expr : float literals and float additions
#   int_expr   : int literals and the current thread index
#   stmt       : target[index] += e
#   block      : a sequence of statements
#   kernel     : a named kernel with input/output buffer names and a body
cohe = ADT("""
module cohe {
float_expr = FloatConst ( float val )
| FloatAdd ( float_expr lhs, float_expr rhs )
int_expr = IntConst ( int val )
| ThreadIdx ( )
stmt = FloatAddTo ( string target, int_expr index, float_expr e )
block = Block ( stmt* s )
kernel = Kernel ( string name, string* in_args, string* out_args, block body )
}
""")
# Inject memoization: ask `adt` to memoize every constructor declared in the
# grammar above (same names, same order as the grammar lists them).
ADTMemo(cohe, ('FloatConst FloatAdd IntConst ThreadIdx '
               'FloatAddTo Block Kernel').split())
# Utilities
def is_const(expr):
    """
    Report whether ``expr`` is a leaf expression that codegen inlines.

    FloatConst and IntConst literals qualify. ThreadIdx does too: it has no
    operands and is rendered as a fixed name, so it never needs a temporary.
    """
    leaf_types = (cohe.FloatConst, cohe.IntConst, cohe.ThreadIdx)
    return isinstance(expr, leaf_types)
def get_const(expr):
    """
    Return the inlined ispc text fragment for a leaf expression.

    Args:
        expr: a FloatConst, IntConst or ThreadIdx node (see ``is_const``).

    Returns:
        The literal value (``float`` or ``int``) for FloatConst/IntConst, or
        the string ``'thread_index'`` (the foreach loop variable emitted by
        the code generator) for ThreadIdx.

    Raises:
        TypeError: if ``expr`` is not a leaf expression. Returning None here
            would end up formatted into the generated code as ``None``.
    """
    if isinstance(expr, (cohe.FloatConst, cohe.IntConst)):
        return expr.val
    if isinstance(expr, cohe.ThreadIdx):
        return 'thread_index'
    raise TypeError(f'Not a constant expression: {type(expr).__name__}')
class Input:
    """Parameter annotation marking a kernel argument as an input buffer."""
class Output:
    """Parameter annotation marking a kernel argument as an output buffer."""
def parse(kernel):
    """
    Given a Python function kernel, parse it to a cohe AST.

    The kernel is never executed: its source is fetched with
    ``inspect.getsource`` and analysed with ``ast`` (Python 3.8+ node types).

    Accepted kernels:
      * every positional parameter is annotated with ``Input`` or ``Output``
        (a qualified name such as ``mod.Output`` also works);
      * the body consists of ``target[index] += expr`` statements, optionally
        mixed with ``pass`` or a docstring, which are ignored;
      * expressions are int/float literals, ``ThreadIdx()`` (bare or
        qualified, e.g. ``cohe.ThreadIdx()``) and ``+`` between float
        expressions.

    Args:
        kernel: the Python function to translate.

    Returns:
        A ``cohe.Kernel`` whose ``in_args``/``out_args`` follow parameter order.

    Raises:
        AssertionError: if the kernel uses an unsupported construct.
    """

    def annotation_name(annotation):
        # Accept both `Output` and attribute-qualified forms like `mod.Output`.
        if isinstance(annotation, ast.Name):
            return annotation.id
        if isinstance(annotation, ast.Attribute):
            return annotation.attr
        assert False, f'Unsupported annotation {type(annotation).__name__}'

    def visit_FunctionDef(node):
        args = node.args
        assert args.vararg is None, 'Variadic *args are not supported'
        assert args.kwarg is None, 'Variadic **kwargs are not supported'
        # Keyword-only parameters would otherwise be silently dropped.
        assert not args.kwonlyargs, 'Keyword-only arguments are not supported'
        params = list(getattr(args, 'posonlyargs', [])) + list(args.args)
        # Check & load the arguments
        inputs = []
        outputs = []
        for arg in params:
            assert arg.annotation is not None, \
                f'Argument {arg.arg} must be annotated with Input or Output'
            kind = annotation_name(arg.annotation)
            if kind == 'Input':
                inputs.append(arg.arg)
            elif kind == 'Output':
                outputs.append(arg.arg)
            else:
                assert False, f'Argument {arg.arg} has unknown annotation {kind}'
        # visit_stmt returns None for statements that carry no semantics.
        body = [s for s in map(visit_stmt, node.body) if s is not None]
        return cohe.Kernel(node.name, inputs, outputs, cohe.Block(body))

    def visit_expr(node):
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name):
                name = func.id
            elif isinstance(func, ast.Attribute):
                name = func.attr
            else:
                assert False, f'Unknown Call node function {type(func).__name__}'
            if name == 'ThreadIdx':
                assert not node.args and not node.keywords, \
                    'ThreadIdx() takes no arguments'
                return cohe.ThreadIdx()
            assert False, f'Unimplemented function call {name}'
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; it is not a valid numeric literal here.
            if isinstance(value, bool):
                assert False, 'Boolean constants are not supported'
            if isinstance(value, int):
                return cohe.IntConst(value)
            if isinstance(value, float):
                return cohe.FloatConst(value)
            assert False, f'Unsupported constant {type(value).__name__}'
        if isinstance(node, ast.BinOp):
            assert isinstance(node.op, ast.Add), 'Only + is supported in expressions'
            lhs = visit_expr(node.left)
            rhs = visit_expr(node.right)
            assert isinstance(lhs, cohe.float_expr) and \
                isinstance(rhs, cohe.float_expr), \
                'Only float expressions can be added'
            return cohe.FloatAdd(lhs, rhs)
        assert False, f'Unknown expr {type(node).__name__}'

    def visit_lhs(node):
        if isinstance(node, ast.Subscript):
            assert isinstance(node.value, ast.Name), \
                'Subscripted target must be a plain name'
            index = node.slice
            # Python < 3.9 wraps the subscript in ast.Index; 3.9+ stores the
            # expression directly (and ast.Index may disappear entirely).
            if isinstance(index, getattr(ast, 'Index', ())):
                index = index.value
            return node.value.id, visit_expr(index)
        assert False, f'Unknown left hand side {type(node).__name__}'

    def visit_stmt(node):
        if isinstance(node, ast.AugAssign):
            assert isinstance(node.op, ast.Add), 'Only += is supported'
            target, index = visit_lhs(node.target)
            value = visit_expr(node.value)
            return cohe.FloatAddTo(target, index, value)
        if isinstance(node, ast.Pass):
            return None
        if isinstance(node, ast.Expr) and isinstance(node.value, ast.Constant) \
                and isinstance(node.value.value, str):
            # Docstrings / bare string literals have no effect on the kernel.
            return None
        assert False, f'Unknown statement {type(node).__name__}'

    # Kernels defined inside functions or classes come back indented, which
    # ast.parse rejects; dedent first.
    source = textwrap.dedent(inspect.getsource(kernel))
    module = ast.parse(source)
    assert len(module.body) == 1, 'Expected exactly one top-level definition'
    func_def = module.body[0]
    assert isinstance(func_def, ast.FunctionDef), 'Kernel must be a plain function'
    return visit_FunctionDef(func_def)
# Codegen
class Codegen:
    """
    cohe AST to ispc code.

    Emitted text accumulates in ``self.code``: call ``emit_kernel`` with a
    ``cohe.Kernel`` and read ``self.code`` afterwards.

    Leaf expressions (literals and ThreadIdx) are inlined where they are
    used. Every other expression is emitted once as a ``float _t<N> = ...;``
    temporary, and ``self.expr_dict`` maps the expression node to ``N``, so
    later uses of the same node refer to the same temporary.
    """

    def __init__(self):
        # Non-leaf expression node -> integer id of its `_t<id>` temporary.
        self.expr_dict = {}
        # Generated ispc source so far.
        self.code = ''
        # Current indentation depth, in tabs.
        self.tab_count = 0

    def get_handle(self, expr):
        """
        Return the text that refers to ``expr``'s value in generated code.

        Leaves yield their inlined value; other expressions must already have
        been emitted by ``emit_expr`` and yield their temporary's name.
        """
        if is_const(expr):
            return get_const(expr)
        # expr_dict stores the numeric id, but emit_expr declares the variable
        # as `_t<id>`; return the variable name, not the bare number.
        return f'_t{self.expr_dict[expr]}'

    def emit_tabs(self):
        """Append indentation for the current nesting depth."""
        self.code += '\t' * self.tab_count

    def emit_expr(self, expr):
        """
        Emit the temporaries needed to compute ``expr`` (operands first).

        Raises:
            AssertionError: for an expression type this generator does not handle.
        """
        if is_const(expr):
            # Const is always inlined to the expression
            return
        if isinstance(expr, cohe.FloatAdd):
            if expr in self.expr_dict:
                # Already emitted: reuse the existing temporary.
                return
            self.emit_expr(expr.lhs)
            self.emit_expr(expr.rhs)
            lhs = self.get_handle(expr.lhs)
            rhs = self.get_handle(expr.rhs)
            expr_id = len(self.expr_dict)
            self.expr_dict[expr] = expr_id
            self.emit_tabs()
            self.code += f'float _t{expr_id} = {lhs} + {rhs};\n'
            return
        assert False, f'Unknown expr {type(expr).__name__}'

    def emit_stmt(self, stmt):
        """Emit ``target[index] += e;`` for a FloatAddTo statement."""
        assert isinstance(stmt, cohe.stmt)
        assert isinstance(stmt, cohe.FloatAddTo)
        # Operands first, so any temporaries are declared before this line.
        self.emit_expr(stmt.e)
        self.emit_expr(stmt.index)
        e = self.get_handle(stmt.e)
        index = self.get_handle(stmt.index)
        self.emit_tabs()
        self.code += f'{stmt.target}[{index}] += {e};\n'

    def emit_block(self, block):
        """Emit each statement of a ``cohe.Block`` in order."""
        assert isinstance(block, cohe.block)
        for stmt in block.s:
            self.emit_stmt(stmt)

    def emit_kernel(self, kernel):
        """
        Emit an ispc function for ``kernel``.

        Each input/output name becomes a ``uniform float *`` parameter,
        followed by ``uniform int num_threads``; the body runs inside
        ``foreach (thread_index = 0 ... num_threads)``.
        """
        assert isinstance(kernel, cohe.kernel)
        args = kernel.in_args + kernel.out_args
        assert len(args) > 0, 'A kernel needs at least one input or output'
        # Temporaries are declared inside this kernel's foreach body and are
        # invisible to any other kernel, so forget previously emitted ones.
        self.expr_dict.clear()
        params = ', '.join(f'uniform float *{arg}' for arg in args)
        self.code += f'void {kernel.name}({params}, uniform int num_threads) {{\n'
        self.tab_count += 1
        self.emit_tabs()
        self.code += 'foreach (thread_index = 0 ... num_threads) {\n'
        self.tab_count += 1
        self.emit_block(kernel.body)
        self.tab_count -= 1
        self.emit_tabs()
        self.code += '}\n'
        self.tab_count -= 1
        self.code += '}\n'
# Program
# Example kernel. It is never called directly: parse() reads its source text
# via inspect.getsource, so the code below the def line *is* the program and
# must stay within what parse() accepts (no docstring here).
# `out` becomes a `uniform float *` output parameter in the generated ispc;
# each foreach iteration adds 1.0 to out[thread_index].
def foo(out: Output):
    out[cohe.ThreadIdx()] += 1.0
# Translate the example kernel to a cohe AST, lower it to ispc, and print it.
prog, cg = parse(foo), Codegen()
cg.emit_kernel(prog)
print(cg.code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.