Created
August 7, 2022 07:06
-
-
Save jackalcooper/223eafe3f592c034d7eaa9fffdbcf44d to your computer and use it in GitHub Desktop.
Some examples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule CfTest do | |
use ExUnit.Case | |
use Beaver | |
alias Beaver.MLIR | |
alias Beaver.MLIR.Type | |
alias Beaver.MLIR.Attribute | |
alias Beaver.MLIR.Dialect.{CF, Arith} | |
@moduletag :smoke | |
defmodule MutCompiler do | |
use Beaver | |
require Beaver.MLIR | |
defmodule Acc do | |
@enforce_keys [:vars, :block, :region] | |
defstruct vars: %{}, block: nil, region: nil | |
end | |
# 3. use the var in acc | |
defp fetch_var(%Acc{vars: vars}, {var_name, _line, nil}) do | |
found = Map.get(vars, var_name) | |
if is_nil(found) do | |
:not_found | |
else | |
{:ok, found} | |
end | |
end | |
defp put_var(%Acc{vars: vars} = acc, {var_name, _line, nil}, var) do | |
%Acc{acc | vars: Map.put(vars, var_name, var)} | |
end | |
defp update_block(%Acc{block: _old} = acc, block) do | |
%Acc{acc | block: block} | |
end | |
# 1. starts with "root", the return expression | |
defp gen_mlir( | |
{:return, _line, [arg]}, | |
%Acc{block: block} = acc | |
) do | |
# we expect it to be a MLIR Value | |
{arg = %MLIR.Value{}, acc} = gen_mlir(arg, acc) | |
mlir = | |
mlir block: block do | |
Func.return(arg) >>> [] | |
end | |
{mlir, acc} | |
end | |
# found {:base_lr, [line: 89], nil} unmatched, so we add this match to extract MLIR Value | |
defp gen_mlir({_var_name, _line, nil} = ast, acc) do | |
# {ast, acc} | |
with {:ok, found} <- fetch_var(acc, ast) do | |
{found, acc} | |
else | |
:not_found -> | |
raise "block arg not found, #{inspect(ast)}" | |
{gen_mlir(ast, acc), acc} | |
end | |
end | |
# For `:if` it is kind of tricky, we need to generate block for it | |
defp gen_mlir( | |
{:if, _, [cond_ast, [do: do_block_ast, else: else_block_ast]]}, | |
%Acc{region: region, block: entry} = acc | |
) do | |
{condition, acc} = gen_mlir(cond_ast, acc) | |
bb_next = | |
mlir do | |
block bb_next(arg >>> Type.f32()) do | |
end | |
end | |
true_branch = | |
mlir do | |
block _true_branch() do | |
{%MLIR.Value{} = mlir, acc} = gen_mlir(do_block_ast, acc) | |
%MLIR.CAPI.MlirBlock{} = Beaver.MLIR.__BLOCK__() | |
CF.br({bb_next, [mlir]}) >>> [] | |
end | |
end | |
false_branch = | |
mlir do | |
block _false_branch() do | |
{%MLIR.Value{} = mlir, acc} = gen_mlir(else_block_ast, acc) | |
CF.br({bb_next, [mlir]}) >>> [] | |
end | |
end | |
mlir block: entry do | |
CF.cond_br(condition, true_branch, false_branch) >>> [] | |
end | |
Beaver.MLIR.CAPI.mlirRegionAppendOwnedBlock(region, true_branch) | |
Beaver.MLIR.CAPI.mlirRegionAppendOwnedBlock(region, false_branch) | |
Beaver.MLIR.CAPI.mlirRegionAppendOwnedBlock(region, bb_next) | |
{arg, update_block(acc, bb_next)} | |
end | |
# an assign, it is different from binding in Elixir, so we want to generate IR of mutable semantic | |
# but this could be complicated, we can simply call gen_mlir with the ast of the var for now | |
defp gen_mlir( | |
{:=, _, | |
[ | |
{_name, _, nil} = ast, | |
var | |
]}, | |
acc | |
) do | |
{mlir, acc} = gen_mlir(var, acc) | |
{mlir, put_var(acc, ast, mlir)} | |
end | |
# in real world, this should be merged with the gen_mlir for :+ | |
defp gen_mlir({:<, _, [left, right]}, %Acc{block: block} = acc) do | |
{left = %MLIR.Value{}, acc} = gen_mlir(left, acc) | |
{right = %MLIR.Value{}, acc} = gen_mlir(right, acc) | |
less = | |
mlir block: block do | |
Arith.cmpf(left, right, predicate: Attribute.integer(Type.i64(), 0)) >>> Type.i1() | |
end | |
{less, acc} | |
end | |
# after adding this, you should see this kind of IR printed | |
# %0 = arith.mulf %arg2, %arg1 : f32 | |
defp gen_mlir({:*, _line, [left, right]}, %Acc{block: block} = acc) do | |
{left = %MLIR.Value{}, acc} = gen_mlir(left, acc) | |
{right = %MLIR.Value{}, acc} = gen_mlir(right, acc) | |
# we only work with float 32 for now | |
add = | |
mlir block: block do | |
Arith.mulf(left, right) >>> Type.f32() | |
end | |
{add, acc} | |
end | |
# at some point you should see logging of node not supported for MLIR CAPI Value, | |
# let's add the match to disable this kind of logging | |
defp gen_mlir(%Beaver.MLIR.Value{} = mlir, acc) do | |
{mlir, acc} | |
end | |
# you might want to print the node not matched | |
defp gen_mlir(ast, acc) do | |
# IO.inspect(ast, label: "node not matched") | |
{ast, acc} | |
end | |
def gen_func(call, block) do | |
# TODO: generate the args | |
{name, _args} = Macro.decompose_call(call) | |
mlir do | |
module do | |
Func.func some_func( | |
sym_name: "\"#{name}\"", | |
function_type: Type.function(List.duplicate(Type.f(32), 4), [Type.f(32)]) | |
) do | |
region do | |
block bb_entry( | |
total_iters >>> Type.f32(), | |
factor >>> Type.f(32), | |
base_lr >>> Type.f(32), | |
step >>> Type.f(32) | |
) do | |
# Put the MLIR Values for args into a Map | |
vars = %{total_iters: total_iters, factor: factor, base_lr: base_lr, step: step} | |
acc = %Acc{ | |
vars: vars, | |
region: Beaver.MLIR.__REGION__(), | |
block: Beaver.MLIR.__BLOCK__() | |
} | |
# keep generating until we meet a terminator | |
{_mlir, _acc} = | |
Macro.prewalk(block, acc, fn ast, %Acc{} = acc -> | |
gen_mlir(ast, acc) | |
end) | |
end | |
end | |
end | |
end | |
end | |
# we let MLIR verify the generated IR for us, so it gonna be legit! | |
|> MLIR.Operation.verify!(dump_if_fail: true) | |
end | |
# In most of LLVM or other compiler guidance, it starts with ast parsing. | |
# In Elixir we don't have to, just reuse the Elixir ast and use macro to do the magic! | |
defmacro defnative(call, do: block) do | |
mlir_asm = | |
MutCompiler.gen_func(call, block) | |
|> MLIR.to_string() | |
quote do | |
alias MLIR.Dialect.Func | |
unquote(Macro.escape(mlir_asm)) | |
# TODO: return a function capturing the JIT | |
# TODO: show how to canonicalize the IR and fold some computation to constants | |
end | |
end | |
end | |
test "cf with mutation" do | |
import MutCompiler | |
mlir = | |
defnative get_lr(total_iters, factor, base_lr, step) do | |
base_lr = base_lr * factor | |
return(base_lr) | |
end | |
assert mlir =~ "%0 = arith.mulf %arg2, %arg1 : f32", mlir | |
mlir = | |
defnative get_lr_with_ctrl_flow(total_iters, factor, base_lr, step) do | |
base_lr = | |
if step < total_iters do | |
base_lr * factor | |
else | |
base_lr | |
end | |
return(base_lr) | |
end | |
assert mlir =~ "%1 = arith.mulf %arg2, %arg1 : f32", mlir | |
assert mlir =~ "return %2 : f32", mlir | |
end | |
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule Manx.Defn do | |
alias __MODULE__.Env | |
alias Beaver.MLIR | |
import MLIR.Sigils | |
import Beaver, only: :macros | |
require Beaver.MLIR | |
alias MLIR.{Type, Attribute} | |
def gen_type({:u, size}), do: Type.i(size) | |
def gen_type({:s, size}), do: Type.i(size) | |
def gen_type({:f, size}), do: Type.f(size) | |
def gen_type({:c, size}), do: Type.complex(Type.f(div(size, 2))) | |
def gen_type(%Nx.Tensor{shape: shape, type: type}) do | |
Tuple.to_list(shape) | |
|> Type.ranked_tensor(gen_type(type)) | |
end | |
def gen_type(tuple) when is_tuple(tuple) do | |
Tuple.to_list(tuple) | |
|> Enum.map(&gen_type/1) | |
|> Type.tuple() | |
end | |
@doc """ | |
In upstream MLIR, there is no lower-able Op packing multiple values into a tuple. | |
If the Nx root type is a tuple, it should be converted to multi-results. | |
This function should always return a list of types | |
""" | |
def gen_root_types(tuple) when is_tuple(tuple) do | |
Tuple.to_list(tuple) | |
|> Enum.map(&gen_type/1) | |
end | |
def gen_root_types(type), do: [gen_type(type)] | |
defp gen_affine_map(shape) do | |
import MLIR.AffineMap | |
rank = tuple_size(shape) | |
exprs = | |
shape | |
|> Tuple.to_list() | |
|> Enum.with_index() | |
|> Enum.map(fn | |
{1, _index} -> 0 | |
{dim_size, index} when dim_size > 1 -> dim(index) | |
end) | |
MLIR.AffineMap.create(rank, 0, exprs) | |
end | |
defp expand_for_output(input_shape, output_shape) | |
when tuple_size(output_shape) >= tuple_size(input_shape) do | |
output_rank = tuple_size(output_shape) | |
rank = tuple_size(input_shape) | |
expanded = List.duplicate(1, output_rank - rank) ++ Tuple.to_list(input_shape) | |
List.to_tuple(expanded) | |
end | |
defp gen_indexing_maps(input1_shape, out_shape) do | |
[ | |
expand_for_output(input1_shape, out_shape) |> gen_affine_map(), | |
gen_affine_map(out_shape) | |
] | |
|> Enum.map(&MLIR.Attribute.affine_map/1) | |
|> Attribute.array() | |
end | |
defp gen_indexing_maps( | |
input1_shape, | |
input2_shape, | |
out_shape | |
) do | |
[ | |
expand_for_output(input1_shape, out_shape) |> gen_affine_map(), | |
expand_for_output(input2_shape, out_shape) |> gen_affine_map(), | |
gen_affine_map(out_shape) | |
] | |
|> Enum.map(&MLIR.Attribute.affine_map/1) | |
|> Attribute.array() | |
end | |
defp gen_iterator_types({}, {}) do | |
~a{[]} | |
end | |
defp gen_iterator_types({_}, {_}) do | |
~a{["parallel"]} | |
end | |
defp gen_iterator_types(_input1, _input2, _output) do | |
~a{["parallel", "parallel"]} | |
end | |
defp gen_expand( | |
_env, | |
value, | |
%{shape: in_shape}, | |
%{shape: out_shape} | |
) | |
when tuple_size(in_shape) == tuple_size(out_shape) do | |
value | |
end | |
defp gen_expand( | |
%Env{block: block}, | |
value, | |
%{type: type, shape: in_shape} = input_t, | |
%{shape: out_shape} = _output_t | |
) do | |
mlir block: block do | |
shape = expand_for_output(in_shape, out_shape) | |
t = %{input_t | type: type, shape: shape} | |
rank_diff = tuple_size(out_shape) - tuple_size(in_shape) | |
pairs = | |
Range.new(0, tuple_size(in_shape) - 1, 1) | |
|> Enum.map(fn i -> [i, i + rank_diff] end) | |
Tensor.expand_shape(value, reassociation: Tensor.reassociation(pairs)) >>> gen_type(t) | |
end | |
end | |
def gen_op(%Env{block: block}, %Nx.Tensor{ | |
data: %Nx.Defn.Expr{op: :parameter, args: [pos]} | |
}) | |
when is_integer(pos) do | |
arg = block |> Beaver.MLIR.Block.get_arg!(pos) | |
arg_cnt = Beaver.Walker.arguments(block) |> Enum.count() | |
if pos >= arg_cnt do | |
raise "arg ##{pos} out of bound, arg_cnt: #{arg_cnt}" | |
end | |
if MLIR.is_null(arg) do | |
raise "arg ##{pos} not found" | |
end | |
arg | |
end | |
def gen_op( | |
%Env{block: block}, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{op: :constant, args: [:nan]}, | |
shape: {}, | |
type: {:f, 32} | |
} = t | |
) do | |
mlir block: block do | |
TOSA.const({:value, ~a{dense<0x7F800001> : tensor<f32>}}) >>> gen_type(t) | |
end | |
end | |
def gen_op( | |
%Env{block: block}, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{op: :constant, args: [:infinity]}, | |
shape: {}, | |
type: {:f, 32} | |
} = t | |
) do | |
mlir block: block do | |
TOSA.const({:value, ~a{dense<0x7F800000> : tensor<f32>}}) >>> | |
gen_type(t) | |
end | |
end | |
def gen_op( | |
%Env{block: block}, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{op: :constant, args: [:neg_infinity]}, | |
shape: {}, | |
type: {:f, 32} | |
} = t | |
) do | |
mlir block: block do | |
_r = | |
TOSA.const({:value, ~a{dense<0xFF800000> : tensor<f32>}}) >>> | |
gen_type(t) | |
end | |
end | |
def gen_op( | |
%Env{block: block}, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{op: :constant, args: [value]}, | |
shape: {} | |
} = t | |
) | |
when is_integer(value) or is_float(value) do | |
mlir block: block do | |
t_str = gen_type(t) |> MLIR.to_string() | |
TOSA.const({:value, ~a{dense<#{value}> : #{t_str}}}) >>> | |
gen_type(t) | |
end | |
end | |
def gen_op( | |
%Env{block: block}, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{op: :constant, args: [%Complex{im: im, re: re}]}, | |
type: {:c, 64} | |
} = t | |
) do | |
mlir block: block do | |
t_str = gen_type(t) |> MLIR.to_string() | |
Arith.constant({:value, ~a[dense<(#{re}, #{im})> : #{t_str}]}) >>> | |
gen_type(t) | |
end | |
end | |
def gen_op( | |
%Env{block: block}, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{ | |
args: [%Nx.Tensor{data: %Nx.BinaryBackend{state: binary}}], | |
op: :tensor | |
} | |
} = t | |
) do | |
mlir block: block do | |
tensor_attr = | |
MLIR.CAPI.mlirDenseElementsAttrRawBufferGet( | |
gen_type(t), | |
byte_size(binary), | |
Beaver.Native.c_string(binary) |> Beaver.Native.Array.as_opaque() | |
) | |
if MLIR.Attribute.is_null(tensor_attr), do: raise("fail to parse tensor dense elements") | |
TOSA.const({:value, tensor_attr}) >>> gen_type(t) | |
end | |
end | |
# unary tosa | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{data: %Nx.Defn.Expr{op: op, args: [input1]}} = t | |
) | |
when op in [ | |
:negate, | |
:abs, | |
:bitwise_not, | |
:exp, | |
:logical_not, | |
:log, | |
:tanh, | |
:rsqrt, | |
:is_nan, | |
:is_infinity, | |
:sigmoid | |
] do | |
mlir block: block do | |
input1_value = gen_op(env, input1) | |
input1_value = TOSA.cast(input1_value) >>> gen_type(%{input1 | type: t.type}) | |
case op do | |
:negate -> | |
TOSA.negate(input1_value) >>> gen_type(t) | |
:abs -> | |
TOSA.abs(input1_value) >>> gen_type(t) | |
:bitwise_not -> | |
TOSA.bitwise_not(input1_value) >>> gen_type(t) | |
:logical_not -> | |
input1_value = TOSA.cast(input1_value) >>> gen_type(%{t | type: {:u, 1}}) | |
result = TOSA.logical_not(input1_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(result) >>> gen_type(t) | |
:exp -> | |
TOSA.exp(input1_value) >>> gen_type(t) | |
:log -> | |
TOSA.log(input1_value) >>> gen_type(t) | |
:tanh -> | |
TOSA.tanh(input1_value) >>> gen_type(t) | |
:rsqrt -> | |
TOSA.rsqrt(input1_value) >>> gen_type(t) | |
:sigmoid -> | |
TOSA.sigmoid(input1_value) >>> gen_type(t) | |
:is_nan -> | |
c = TOSA.equal(input1_value, input1_value) >>> gen_type(%{t | type: {:u, 1}}) | |
c = TOSA.logical_not(c) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:is_infinity -> | |
input1_value = gen_op(env, input1) | |
input1_type_str = gen_type(input1) |> MLIR.to_string() | |
inf = | |
TOSA.const({:value, ~a{dense<0x7F800000> : #{input1_type_str}}}) >>> gen_type(input1) | |
abs = TOSA.abs(input1_value) >>> gen_type(input1) | |
equal = TOSA.equal(inf, abs) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(equal) >>> gen_type(t) | |
end | |
end | |
end | |
def gen_op( | |
env, | |
%Nx.Tensor{shape: {}, data: %Nx.Defn.Expr{op: :all, args: [%{shape: {}} = input1, _]}} | |
) do | |
gen_op(env, input1) | |
end | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{ | |
op: :all, | |
args: [%{shape: in_shape} = input1, [axes: axes, keep_axes: keep_axes]] | |
} | |
} = t | |
) | |
when is_list(axes) do | |
mlir block: block do | |
input1 = gen_op(env, input1) | |
input1 = TOSA.cast(input1) >>> gen_type(%{t | shape: in_shape, type: {:u, 1}}) | |
{in_shape, mlir_value} = | |
Enum.reduce( | |
axes, | |
{Tuple.to_list(in_shape), input1}, | |
fn axis, {in_shape, mlir_value} -> | |
out_shape = List.replace_at(in_shape, axis, 1) | |
reduced = | |
TOSA.reduce_all(mlir_value, axis: Attribute.integer(Type.i64(), axis)) >>> | |
gen_type(%{t | shape: List.to_tuple(out_shape), type: {:u, 1}}) | |
{out_shape, reduced} | |
end | |
) | |
mlir_value = TOSA.cast(mlir_value) >>> gen_type(%{t | shape: List.to_tuple(in_shape)}) | |
if keep_axes do | |
mlir_value | |
else | |
Tensor.collapse_shape(mlir_value, reassociation: Tensor.reassociation([])) >>> gen_type(t) | |
end | |
end | |
end | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{ | |
data: | |
%Nx.Defn.Expr{ | |
op: :all, | |
args: [%{shape: in_shape} = input1, [axes: nil, keep_axes: keep_axes]] | |
} = expr | |
} = t | |
) do | |
# if axes is nil, replace it with a list of every axis | |
mlir block: block do | |
rank = tuple_size(in_shape) | |
axes = Range.new(0, rank - 1, 1) |> Enum.to_list() | |
expr = %{ | |
expr | |
| args: [input1, [axes: axes, keep_axes: keep_axes]] | |
} | |
gen_op(env, %{t | data: expr}) | |
end | |
end | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{ | |
op: :conjugate, | |
args: [%Nx.Tensor{type: {:c, 64}} = complex_tensor] | |
}, | |
shape: {} | |
} = t | |
) do | |
mlir block: block do | |
complex_tensor = gen_op(env, complex_tensor) | |
complex_element = Tensor.extract(complex_tensor) >>> Type.complex(Type.f32()) | |
conjugate_element = Complex.conj(complex_element) >>> Type.complex(Type.f32()) | |
conjugate_tensor = | |
Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0])) >>> | |
gen_type(t) | |
Tensor.insert(conjugate_element, conjugate_tensor) >>> | |
gen_type(t) | |
end | |
end | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{op: :conjugate, args: [%Nx.Tensor{} = real_tensor]}, | |
shape: {}, | |
type: complex_type = {:c, 64} | |
} = t | |
) do | |
mlir block: block do | |
real_tensor = gen_op(env, real_tensor) | |
real_tensor = TOSA.cast(real_tensor) >>> Type.ranked_tensor([], Type.f32()) | |
real = Tensor.extract(real_tensor) >>> Type.f32() | |
conjugate_tensor = | |
Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0])) >>> | |
gen_type(t) | |
imaginary = Arith.constant(value: Attribute.float(Type.f32(), 0.0)) >>> Type.f32() | |
complex_element_t = gen_type(complex_type) | |
complex_element = Complex.create(real, imaginary) >>> complex_element_t | |
conjugate_element = Complex.conj(complex_element) >>> complex_element_t | |
_ = Tensor.insert(conjugate_element, conjugate_tensor) >>> gen_type(t) | |
end | |
end | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{op: :conjugate, args: [complex_tensor]}, | |
shape: shape | |
} = t | |
) do | |
mlir block: block do | |
element_cnt = Enum.reduce(Tuple.to_list(shape), 1, &*/2) | |
complex_tensor = gen_op(env, complex_tensor) | |
lower = Arith.constant(value: Attribute.integer(Type.index(), 0)) >>> Type.index() | |
upper = Arith.constant(value: Attribute.integer(Type.index(), element_cnt)) >>> Type.index() | |
step = Arith.constant(value: Attribute.integer(Type.index(), 1)) >>> Type.index() | |
conjugate_tensor = | |
Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0])) >>> | |
gen_type(t) | |
conjugate_memref = | |
Bufferization.to_memref(conjugate_tensor) >>> | |
Type.memref([2], Type.complex(Type.f32())) | |
SCF.for [lower, upper, step] do | |
region do | |
block inner(index >>> Type.index()) do | |
complex_element = Tensor.extract(complex_tensor, index) >>> Type.complex(Type.f32()) | |
conjugate_element = Complex.conj(complex_element) >>> Type.complex(Type.f32()) | |
MemRef.store(conjugate_element, conjugate_memref, index) >>> [] | |
SCF.yield() >>> [] | |
end | |
end | |
end >>> [] | |
conjugate_tensor | |
end | |
end | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{ | |
data: %Nx.Defn.Expr{ | |
op: :imag, | |
args: [%Nx.Tensor{type: {:c, 64}, shape: in_shape} = in_tensor] | |
}, | |
shape: out_shape | |
} = t | |
) do | |
mlir block: block do | |
in_tensor = gen_op(env, in_tensor) | |
out_tensor = | |
Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0])) >>> | |
gen_type(t) | |
Linalg.generic [ | |
in_tensor, | |
out_tensor, | |
operand_segment_sizes: ODS.operand_segment_sizes([1, 1]), | |
indexing_maps: gen_indexing_maps(in_shape, out_shape), | |
iterator_types: gen_iterator_types(in_shape, out_shape) | |
] do | |
region do | |
block bb0(arg0 >>> Type.complex(Type.f32()), arg1 >>> Type.f(32)) do | |
%MLIR.Value{} = arg1 | |
im = Complex.im(arg0) >>> Type.f32() | |
Linalg.yield([im]) >>> [] | |
end | |
end | |
end >>> gen_type(t) | |
end | |
end | |
# unary linalg | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{type: type, data: %Nx.Defn.Expr{op: op, args: [input]}} = t | |
) | |
when op in [ | |
:population_count, | |
:count_leading_zeros, | |
:cos, | |
:sin, | |
:sqrt, | |
:tan, | |
:erf, | |
:cbrt, | |
:expm1, | |
:log1p | |
] do | |
mlir block: block do | |
input_value = gen_op(env, input) | |
input_value = TOSA.cast(input_value) >>> gen_type(t) | |
out_tensor = | |
Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0])) >>> | |
gen_type(t) | |
Linalg.generic [ | |
input_value, | |
out_tensor, | |
operand_segment_sizes: ODS.operand_segment_sizes([1, 1]), | |
indexing_maps: gen_indexing_maps(input.shape, t.shape), | |
iterator_types: gen_iterator_types(input.shape, t.shape) | |
] do | |
region do | |
block bb0(arg0 >>> gen_type(type), out >>> gen_type(type)) do | |
%MLIR.Value{} = out | |
result = | |
case op do | |
:population_count -> | |
Math.ctpop(arg0) >>> gen_type(type) | |
:count_leading_zeros -> | |
Math.ctlz(arg0) >>> gen_type(type) | |
:cos -> | |
Math.cos(arg0) >>> gen_type(type) | |
:sin -> | |
Math.sin(arg0) >>> gen_type(type) | |
:sqrt -> | |
Math.sqrt(arg0) >>> gen_type(type) | |
:tan -> | |
Math.tan(arg0) >>> gen_type(type) | |
:erf -> | |
Math.erf(arg0) >>> gen_type(type) | |
:cbrt -> | |
abs = Math.abs(arg0) >>> gen_type(type) | |
third = | |
Arith.constant(value: Attribute.float(gen_type(type), 0.333333343)) >>> | |
gen_type(type) | |
pow = Math.powf(abs, third) >>> gen_type(type) | |
Math.copysign(pow, arg0) >>> gen_type(type) | |
:expm1 -> | |
Math.expm1(arg0) >>> gen_type(type) | |
:log1p -> | |
Math.log1p(arg0) >>> gen_type(type) | |
end | |
Linalg.yield(result) >>> [] | |
end | |
end | |
end >>> gen_type(t) | |
end | |
end | |
# binary linalg | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{type: type, data: %Nx.Defn.Expr{op: op, args: [a, b]}} = t | |
) | |
when op in [:remainder, :atan2] do | |
mlir block: block do | |
a_value = gen_op(env, a) | |
a_value = gen_expand(env, a_value, a, t) | |
b_value = gen_op(env, b) | |
b_value = gen_expand(env, b_value, b, t) | |
out_tensor = | |
Bufferization.alloc_tensor(operand_segment_sizes: ODS.operand_segment_sizes([0, 0])) >>> | |
gen_type(t) | |
Linalg.generic [ | |
a_value, | |
b_value, | |
out_tensor, | |
operand_segment_sizes: ODS.operand_segment_sizes([2, 1]), | |
indexing_maps: gen_indexing_maps(a.shape, b.shape, t.shape), | |
iterator_types: gen_iterator_types(a.shape, b.shape, t.shape) | |
] do | |
region do | |
block bb0(arg0 >>> gen_type(type), arg1 >>> gen_type(type), out >>> gen_type(type)) do | |
%MLIR.Value{} = out | |
result = | |
case op do | |
:remainder -> | |
case type do | |
{:f, _} -> | |
Arith.remf(arg0, arg1) >>> gen_type(type) | |
{:i, _} -> | |
Arith.remui(arg0, arg1) >>> gen_type(type) | |
{:s, _} -> | |
Arith.remsi(arg0, arg1) >>> gen_type(type) | |
end | |
:atan2 -> | |
Math.atan2(arg0, arg1) >>> gen_type(type) | |
end | |
Linalg.yield(result) >>> [] | |
end | |
end | |
end >>> gen_type(t) | |
end | |
end | |
def gen_op(env, %Nx.Tensor{ | |
data: %Nx.Defn.Expr{ | |
op: :optional, | |
args: | |
[ | |
%{ | |
data: %{op: :logical_not} | |
}, | |
%{ | |
data: %{op: :equal} | |
} | |
] = list | |
} | |
}) do | |
gen_op(env, List.first(list)) | |
end | |
# binary tosa | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{data: %Nx.Defn.Expr{op: op, args: [a, b]}} = t | |
) do | |
mlir block: block do | |
a_t = %{a | type: t.type} |> gen_type | |
b_t = %{b | type: t.type} |> gen_type | |
a_value = gen_op(env, a) | |
b_value = gen_op(env, b) | |
{a_value, b_value} = | |
case op do | |
_ when op in [:equal] -> | |
b_value = | |
if a.type != b.type do | |
TOSA.cast(b_value) >>> gen_type(%{b | type: a.type}) | |
else | |
b_value | |
end | |
{a_value, b_value} | |
_ when op in [:logical_or, :logical_xor, :logical_and] -> | |
a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: {:u, 1}}) | |
b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: {:u, 1}}) | |
{a_value, b_value} | |
_ -> | |
a_value = TOSA.cast(a_value) >>> a_t | |
b_value = TOSA.cast(b_value) >>> b_t | |
{a_value, b_value} | |
end | |
case op do | |
:subtract -> | |
TOSA.sub(a_value, b_value) >>> gen_type(t) | |
:less_equal -> | |
c = TOSA.greater_equal(b_value, a_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:greater_equal -> | |
c = TOSA.greater_equal(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:less -> | |
c = TOSA.greater(b_value, a_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:greater -> | |
c = TOSA.greater(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:equal -> | |
c = TOSA.equal(b_value, a_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:not_equal -> | |
c = TOSA.equal(b_value, a_value) >>> gen_type(%{t | type: {:u, 1}}) | |
c = TOSA.logical_not(c) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:logical_and -> | |
c = TOSA.logical_and(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:logical_or -> | |
c = TOSA.logical_or(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:logical_xor -> | |
c = TOSA.logical_xor(a_value, b_value) >>> gen_type(%{t | type: {:u, 1}}) | |
TOSA.cast(c) >>> gen_type(t) | |
:add -> | |
TOSA.add(a_value, b_value) >>> gen_type(t) | |
:max -> | |
TOSA.maximum(a_value, b_value) >>> gen_type(t) | |
:min -> | |
TOSA.minimum(a_value, b_value) >>> gen_type(t) | |
:bitwise_and -> | |
TOSA.bitwise_and(a_value, b_value) >>> gen_type(t) | |
:bitwise_or -> | |
TOSA.bitwise_or(a_value, b_value) >>> gen_type(t) | |
:left_shift -> | |
TOSA.logical_left_shift(a_value, b_value) >>> gen_type(t) | |
:right_shift -> | |
case t.type do | |
{:u, _} -> | |
TOSA.logical_right_shift(a_value, b_value) >>> gen_type(t) | |
{:s, _} -> | |
TOSA.arithmetic_right_shift(a_value, b_value, round: Attribute.bool(false)) >>> | |
gen_type(t) | |
end | |
:multiply -> | |
TOSA.mul(a_value, b_value, shift: Attribute.integer(Type.i32(), 0)) >>> gen_type(t) | |
:divide -> | |
b_r = TOSA.reciprocal(b_value) >>> b_t | |
TOSA.mul(a_value, b_r, shift: Attribute.integer(Type.i(32), 0)) >>> gen_type(t) | |
:quotient -> | |
a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: {:u, 32}}) | |
b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: {:u, 32}}) | |
result = TOSA.div(a_value, b_value) >>> gen_type(%{t | type: {:u, 32}}) | |
TOSA.cast(result) >>> gen_type(t) | |
:power -> | |
{_, width} = a.type | |
width = min(width, 32) | |
a_value = TOSA.cast(a_value) >>> gen_type(%{a | type: {:f, width}}) | |
b_value = TOSA.cast(b_value) >>> gen_type(%{b | type: {:f, width}}) | |
result = TOSA.pow(a_value, b_value) >>> gen_type(%{t | type: {:f, width}}) | |
TOSA.cast(result) >>> gen_type(t) | |
_ -> | |
raise "Unsupported binary op: #{inspect(t, structs: false, pretty: true)}" | |
end | |
end | |
end | |
def gen_op( | |
%Env{block: block} = env, | |
%Nx.Tensor{data: %Nx.Defn.Expr{op: :select, args: [pred, on_true, on_false]}} = t | |
) do | |
mlir block: block do | |
pred_value = gen_op(env, pred) | |
pred_t = %{pred | type: {:u, 1}} | |
pred_value = TOSA.cast(pred_value) >>> gen_type(pred_t) | |
pred_value = gen_expand(env, pred_value, pred_t, t) | |
on_true_value = gen_op(env, on_true) | |
on_false_value = gen_op(env, on_false) | |
on_true_value = TOSA.cast(on_true_value) >>> gen_type(%{on_true | type: t.type}) | |
on_false_value = TOSA.cast(on_false_value) >>> gen_type(%{on_false | type: t.type}) | |
TOSA.select(pred_value, on_true_value, on_false_value) >>> gen_type(t) | |
end | |
end | |
def gen_op(%Env{} = env, tuple) when is_tuple(tuple) do | |
tuple | |
|> Tuple.to_list() | |
|> Enum.map(&gen_op(env, &1)) | |
|> List.to_tuple() | |
end | |
def gen_op(_, tensor) do | |
raise "op not supported: " <> inspect(tensor, structs: false, pretty: true) | |
end | |
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule CFTest do | |
use ExUnit.Case | |
use Beaver | |
alias Beaver.MLIR | |
alias Beaver.MLIR.{Attribute, Type} | |
import ExUnit.CaptureIO | |
test "generate mlir with function calls" do | |
ir = | |
mlir do | |
module do | |
Func.func some_func(function_type: Type.function([], [Type.i(32)])) do | |
region do | |
block bb_entry() do | |
v0 = Arith.constant(value: Attribute.integer(Type.i(32), 0)) >>> Type.i(32) | |
cond0 = Arith.constant(true) >>> Type.i(1) | |
CF.cond_br(cond0, MLIR.__BLOCK__(bb1), {MLIR.__BLOCK__(bb2), [v0]}) >>> [] | |
end | |
block bb1() do | |
v1 = Arith.constant(value: Attribute.integer(Type.i(32), 0)) >>> Type.i(32) | |
_add = Arith.addi(v0, v0) >>> Type.i(32) | |
CF.br({MLIR.__BLOCK__(bb2), [v1]}) >>> [] | |
end | |
block bb2(arg >>> Type.i(32)) do | |
v2 = Arith.constant(value: Attribute.integer(Type.i(32), 0)) >>> Type.i(32) | |
add = Arith.addi(arg, v2) >>> Type.i(32) | |
Func.return(add) >>> [] | |
end | |
end | |
end | |
|> MLIR.Operation.verify!(dump_if_fail: true) | |
Func.func some_func2(function_type: Type.function([], [Type.i(32)])) do | |
region do | |
block bb_entry() do | |
v0 = Arith.constant(value: Attribute.integer(Type.i(32), 0)) >>> Type.i(32) | |
_add = Arith.addi(v0, v0) >>> Type.i(32) | |
CF.br({MLIR.__BLOCK__(bb1), [v0]}) >>> [] | |
end | |
block bb1(arg >>> Type.i(32)) do | |
v2 = Arith.constant(value: Attribute.integer(Type.i(32), 0)) >>> Type.i(32) | |
add = Arith.addi(arg, v2) >>> Type.i(32) | |
_sub = Arith.subi(arg, v2) >>> Type.i(32) | |
_mul = Arith.muli(arg, v2) >>> Type.i(32) | |
_div = Arith.divsi(arg, v2) >>> Type.i(32) | |
Func.return(add) >>> [] | |
end | |
end | |
end | |
end | |
end | |
|> MLIR.Operation.verify!() | |
captured = | |
capture_io(fn -> | |
IO.inspect(ir) | |
end) | |
assert captured =~ ~r"module {" | |
assert captured =~ ~r"// pred.+bb0" | |
assert captured =~ ~r"// 2 preds.+bb0.+bb1" | |
end | |
end |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment