Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Last active May 14, 2025 16:00
Show Gist options
  • Select an option

  • Save llandsmeer/240f5718350fef6feca57010ba4888b3 to your computer and use it in GitHub Desktop.

Select an option

Save llandsmeer/240f5718350fef6feca57010ba4888b3 to your computer and use it in GitHub Desktop.
StreamHLS with JAX
# Build pipeline: JAX -> StableHLO -> linalg (IREE) -> StreamHLS host/kernel C++.
# -e: abort on the first failing command; -x: echo each command as it runs.
set -ex

# Activate the Python virtual environment that provides jax.
. ../env/bin/activate

# Writes jax.stablehlo.mlir (and prints the lowered module to stdout).
python3 jax_test.py

# Convert the StableHLO module into the form consumed by streamhls-opt.
iree-opt ./jax.stablehlo.mlir \
    --iree-stablehlo-input-transformation-pipeline \
    --convert-scf-to-cf \
    > jax.linalg.mlir

# Host side: run the host pipeline and emit the C++ testbench.
streamhls-opt ./jax.linalg.mlir \
    -streamhls-host-pipeline \
    > host.mlir
streamhls-translate ./host.mlir \
    -emit-vivado-hls \
    -vitis-hls-weights-dir=data \
    -vitis-hls-is-host=true \
    -o host_tb.cpp

# Kernel side.
# put top-func=main
# NOTE(review): jax_test.py renames '@main(' to '@forward(' before writing the
# module; confirm that top-func=main is still the intended entry point here.
#
# Every continuation line inside the quoted option string ends in " \" (with a
# space): bash deletes backslash-newline inside double quotes, so without the
# space adjacent options would be glued together (graph-file=graphreport-file=...).
# The option lines are deliberately left unindented so exactly one space
# separates consecutive options.
streamhls-opt jax.linalg.mlir \
    -streamhls-kernel-pipeline="top-func=main \
graph-file=graph \
report-file=report \
optimize-schedule=1 \
parallelize-nodes=1 \
combined-optimization=0 \
board-dsps=1024 \
tiling-limit=16 \
time-limit-minutes=10 \
bufferize-func-args=0 \
optimize-conv-reuse=0 \
minimize-on-chip-buffers=0 \
debug-point=14" > kernel.mlir

# Emit the HLS C++ for the kernel.
streamhls-translate \
    kernel.mlir \
    -emit-vivado-hls \
    -o kernel.cpp
from jax import numpy as jnp
from jax import jit
import jax
def f(rs):
    """Run 20 steps of a modified logistic map for every rate in ``rs``.

    Each element ``r`` of ``rs`` is handled independently through ``jax.vmap``.
    The state starts at ``x = 0.5``; a single step replaces ``x`` with ``-x``
    when ``x > 0.5`` and with ``r * x * (1 - x)`` otherwise. The scan is fully
    unrolled (``unroll=True``), so the lowered module contains all 20 steps
    as straight-line code.

    Args:
        rs: array of ``r`` values; ``jax.vmap`` maps over its leading axis
            (the example driver passes a shape-(4,) array).

    Returns:
        The final state ``x`` after 20 steps, one value per element of ``rs``.
    """
    def final_state(r):
        def step(x, _):
            # scan expects (new_carry, per_step_output); the per-step output
            # (the previous x) is discarded by the caller.
            next_x = jax.lax.select(x > 0.5, -x, r * x * (1 - x))
            return next_x, x

        # [0] keeps only the final carry, dropping the stacked per-step outputs.
        return jax.lax.scan(step, 0.5, length=20, unroll=True)[0]

    o = jax.vmap(final_state)(rs)
    return o
    # Earlier experiments, kept for reference:
    # return jax.lax.select(x > 3, x - 5, x + 5)
    # return x + 1
# Example input used only to fix the argument shape/dtype for lowering.
x = jnp.ones((4,))

# Lower the jitted function and grab the textual MLIR module.
mlir = jax.jit(f).lower(x).as_text()

# Write the module with its entry point renamed from @main to @forward.
# The handle is deliberately not called `f`, so the function `f` defined above
# is not rebound to a (closed) file object once the `with` block exits.
with open('jax.stablehlo.mlir', 'w') as out_file:
    print(mlir.replace('@main(', '@forward('), file=out_file)

# Also echo the original (un-renamed) module to stdout for inspection.
print(mlir)
#map = affine_map<(d0) -> (d0)>
module @jit_f {
func.func public @forward(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%cst = arith.constant -5.000000e-01 : f32
%false = arith.constant false
%cst_0 = arith.constant 1.000000e+00 : f32
%cst_1 = arith.constant 5.000000e-01 : f32
%0 = tensor.empty() : tensor<4xi1>
%1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<4xi1>) {
^bb0(%out: i1):
linalg.yield %false : i1
} -> tensor<4xi1>
%2 = tensor.empty() : tensor<4xf32>
%3 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%2 : tensor<4xf32>) {
^bb0(%out: f32):
linalg.yield %cst : f32
} -> tensor<4xf32>
%4 = tensor.empty() : tensor<4xf32>
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<4xf32>) outs(%4 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.mulf %in, %cst_1 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%6 = tensor.empty() : tensor<4xf32>
%7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%5 : tensor<4xf32>) outs(%6 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.mulf %in, %cst_1 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%8 = tensor.empty() : tensor<4xf32>
%9 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1, %3, %7 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%8 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%10 = tensor.empty() : tensor<4xi1>
%11 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%9 : tensor<4xf32>) outs(%10 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%12 = tensor.empty() : tensor<4xf32>
%13 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%9 : tensor<4xf32>) outs(%12 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%14 = tensor.empty() : tensor<4xf32>
%15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %9 : tensor<4xf32>, tensor<4xf32>) outs(%14 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%16 = tensor.empty() : tensor<4xf32>
%17 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%9 : tensor<4xf32>) outs(%16 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%18 = tensor.empty() : tensor<4xf32>
%19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%15, %17 : tensor<4xf32>, tensor<4xf32>) outs(%18 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%20 = tensor.empty() : tensor<4xf32>
%21 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%11, %13, %19 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%20 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%22 = tensor.empty() : tensor<4xi1>
%23 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%21 : tensor<4xf32>) outs(%22 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%24 = tensor.empty() : tensor<4xf32>
%25 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%21 : tensor<4xf32>) outs(%24 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%26 = tensor.empty() : tensor<4xf32>
%27 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %21 : tensor<4xf32>, tensor<4xf32>) outs(%26 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%28 = tensor.empty() : tensor<4xf32>
%29 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%21 : tensor<4xf32>) outs(%28 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%30 = tensor.empty() : tensor<4xf32>
%31 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%27, %29 : tensor<4xf32>, tensor<4xf32>) outs(%30 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%32 = tensor.empty() : tensor<4xf32>
%33 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%23, %25, %31 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%32 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%34 = tensor.empty() : tensor<4xi1>
%35 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%33 : tensor<4xf32>) outs(%34 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%36 = tensor.empty() : tensor<4xf32>
%37 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%33 : tensor<4xf32>) outs(%36 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%38 = tensor.empty() : tensor<4xf32>
%39 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %33 : tensor<4xf32>, tensor<4xf32>) outs(%38 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%40 = tensor.empty() : tensor<4xf32>
%41 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%33 : tensor<4xf32>) outs(%40 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%42 = tensor.empty() : tensor<4xf32>
%43 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%39, %41 : tensor<4xf32>, tensor<4xf32>) outs(%42 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%44 = tensor.empty() : tensor<4xf32>
%45 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%35, %37, %43 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%44 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%46 = tensor.empty() : tensor<4xi1>
%47 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%45 : tensor<4xf32>) outs(%46 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%48 = tensor.empty() : tensor<4xf32>
%49 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%45 : tensor<4xf32>) outs(%48 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%50 = tensor.empty() : tensor<4xf32>
%51 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %45 : tensor<4xf32>, tensor<4xf32>) outs(%50 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%52 = tensor.empty() : tensor<4xf32>
%53 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%45 : tensor<4xf32>) outs(%52 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%54 = tensor.empty() : tensor<4xf32>
%55 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%51, %53 : tensor<4xf32>, tensor<4xf32>) outs(%54 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%56 = tensor.empty() : tensor<4xf32>
%57 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%47, %49, %55 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%56 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%58 = tensor.empty() : tensor<4xi1>
%59 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%57 : tensor<4xf32>) outs(%58 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%60 = tensor.empty() : tensor<4xf32>
%61 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%57 : tensor<4xf32>) outs(%60 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%62 = tensor.empty() : tensor<4xf32>
%63 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %57 : tensor<4xf32>, tensor<4xf32>) outs(%62 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%64 = tensor.empty() : tensor<4xf32>
%65 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%57 : tensor<4xf32>) outs(%64 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%66 = tensor.empty() : tensor<4xf32>
%67 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%63, %65 : tensor<4xf32>, tensor<4xf32>) outs(%66 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%68 = tensor.empty() : tensor<4xf32>
%69 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%59, %61, %67 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%68 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%70 = tensor.empty() : tensor<4xi1>
%71 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%69 : tensor<4xf32>) outs(%70 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%72 = tensor.empty() : tensor<4xf32>
%73 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%69 : tensor<4xf32>) outs(%72 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%74 = tensor.empty() : tensor<4xf32>
%75 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %69 : tensor<4xf32>, tensor<4xf32>) outs(%74 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%76 = tensor.empty() : tensor<4xf32>
%77 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%69 : tensor<4xf32>) outs(%76 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%78 = tensor.empty() : tensor<4xf32>
%79 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%75, %77 : tensor<4xf32>, tensor<4xf32>) outs(%78 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%80 = tensor.empty() : tensor<4xf32>
%81 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%71, %73, %79 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%80 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%82 = tensor.empty() : tensor<4xi1>
%83 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%81 : tensor<4xf32>) outs(%82 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%84 = tensor.empty() : tensor<4xf32>
%85 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%81 : tensor<4xf32>) outs(%84 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%86 = tensor.empty() : tensor<4xf32>
%87 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %81 : tensor<4xf32>, tensor<4xf32>) outs(%86 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%88 = tensor.empty() : tensor<4xf32>
%89 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%81 : tensor<4xf32>) outs(%88 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%90 = tensor.empty() : tensor<4xf32>
%91 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%87, %89 : tensor<4xf32>, tensor<4xf32>) outs(%90 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%92 = tensor.empty() : tensor<4xf32>
%93 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%83, %85, %91 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%92 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%94 = tensor.empty() : tensor<4xi1>
%95 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%93 : tensor<4xf32>) outs(%94 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%96 = tensor.empty() : tensor<4xf32>
%97 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%93 : tensor<4xf32>) outs(%96 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%98 = tensor.empty() : tensor<4xf32>
%99 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %93 : tensor<4xf32>, tensor<4xf32>) outs(%98 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%100 = tensor.empty() : tensor<4xf32>
%101 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%93 : tensor<4xf32>) outs(%100 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%102 = tensor.empty() : tensor<4xf32>
%103 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%99, %101 : tensor<4xf32>, tensor<4xf32>) outs(%102 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%104 = tensor.empty() : tensor<4xf32>
%105 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%95, %97, %103 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%104 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%106 = tensor.empty() : tensor<4xi1>
%107 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%105 : tensor<4xf32>) outs(%106 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%108 = tensor.empty() : tensor<4xf32>
%109 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%105 : tensor<4xf32>) outs(%108 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%110 = tensor.empty() : tensor<4xf32>
%111 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %105 : tensor<4xf32>, tensor<4xf32>) outs(%110 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%112 = tensor.empty() : tensor<4xf32>
%113 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%105 : tensor<4xf32>) outs(%112 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%114 = tensor.empty() : tensor<4xf32>
%115 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%111, %113 : tensor<4xf32>, tensor<4xf32>) outs(%114 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%116 = tensor.empty() : tensor<4xf32>
%117 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%107, %109, %115 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%116 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%118 = tensor.empty() : tensor<4xi1>
%119 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%117 : tensor<4xf32>) outs(%118 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%120 = tensor.empty() : tensor<4xf32>
%121 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%117 : tensor<4xf32>) outs(%120 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%122 = tensor.empty() : tensor<4xf32>
%123 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %117 : tensor<4xf32>, tensor<4xf32>) outs(%122 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%124 = tensor.empty() : tensor<4xf32>
%125 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%117 : tensor<4xf32>) outs(%124 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%126 = tensor.empty() : tensor<4xf32>
%127 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%123, %125 : tensor<4xf32>, tensor<4xf32>) outs(%126 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%128 = tensor.empty() : tensor<4xf32>
%129 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%119, %121, %127 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%128 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%130 = tensor.empty() : tensor<4xi1>
%131 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%129 : tensor<4xf32>) outs(%130 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%132 = tensor.empty() : tensor<4xf32>
%133 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%129 : tensor<4xf32>) outs(%132 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%134 = tensor.empty() : tensor<4xf32>
%135 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %129 : tensor<4xf32>, tensor<4xf32>) outs(%134 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%136 = tensor.empty() : tensor<4xf32>
%137 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%129 : tensor<4xf32>) outs(%136 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%138 = tensor.empty() : tensor<4xf32>
%139 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%135, %137 : tensor<4xf32>, tensor<4xf32>) outs(%138 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%140 = tensor.empty() : tensor<4xf32>
%141 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%131, %133, %139 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%140 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%142 = tensor.empty() : tensor<4xi1>
%143 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%141 : tensor<4xf32>) outs(%142 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%144 = tensor.empty() : tensor<4xf32>
%145 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%141 : tensor<4xf32>) outs(%144 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%146 = tensor.empty() : tensor<4xf32>
%147 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %141 : tensor<4xf32>, tensor<4xf32>) outs(%146 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%148 = tensor.empty() : tensor<4xf32>
%149 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%141 : tensor<4xf32>) outs(%148 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%150 = tensor.empty() : tensor<4xf32>
%151 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%147, %149 : tensor<4xf32>, tensor<4xf32>) outs(%150 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%152 = tensor.empty() : tensor<4xf32>
%153 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%143, %145, %151 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%152 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%154 = tensor.empty() : tensor<4xi1>
%155 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%153 : tensor<4xf32>) outs(%154 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%156 = tensor.empty() : tensor<4xf32>
%157 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%153 : tensor<4xf32>) outs(%156 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%158 = tensor.empty() : tensor<4xf32>
%159 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %153 : tensor<4xf32>, tensor<4xf32>) outs(%158 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%160 = tensor.empty() : tensor<4xf32>
%161 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%153 : tensor<4xf32>) outs(%160 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%162 = tensor.empty() : tensor<4xf32>
%163 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%159, %161 : tensor<4xf32>, tensor<4xf32>) outs(%162 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%164 = tensor.empty() : tensor<4xf32>
%165 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%155, %157, %163 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%164 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%166 = tensor.empty() : tensor<4xi1>
%167 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%165 : tensor<4xf32>) outs(%166 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%168 = tensor.empty() : tensor<4xf32>
%169 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%165 : tensor<4xf32>) outs(%168 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%170 = tensor.empty() : tensor<4xf32>
%171 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %165 : tensor<4xf32>, tensor<4xf32>) outs(%170 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%172 = tensor.empty() : tensor<4xf32>
%173 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%165 : tensor<4xf32>) outs(%172 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%174 = tensor.empty() : tensor<4xf32>
%175 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%171, %173 : tensor<4xf32>, tensor<4xf32>) outs(%174 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%176 = tensor.empty() : tensor<4xf32>
%177 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%167, %169, %175 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%176 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%178 = tensor.empty() : tensor<4xi1>
%179 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%177 : tensor<4xf32>) outs(%178 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%180 = tensor.empty() : tensor<4xf32>
%181 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%177 : tensor<4xf32>) outs(%180 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%182 = tensor.empty() : tensor<4xf32>
%183 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %177 : tensor<4xf32>, tensor<4xf32>) outs(%182 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%184 = tensor.empty() : tensor<4xf32>
%185 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%177 : tensor<4xf32>) outs(%184 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%186 = tensor.empty() : tensor<4xf32>
%187 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%183, %185 : tensor<4xf32>, tensor<4xf32>) outs(%186 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%188 = tensor.empty() : tensor<4xf32>
%189 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%179, %181, %187 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%188 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%190 = tensor.empty() : tensor<4xi1>
%191 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%189 : tensor<4xf32>) outs(%190 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%192 = tensor.empty() : tensor<4xf32>
%193 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%189 : tensor<4xf32>) outs(%192 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%194 = tensor.empty() : tensor<4xf32>
%195 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %189 : tensor<4xf32>, tensor<4xf32>) outs(%194 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%196 = tensor.empty() : tensor<4xf32>
%197 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%189 : tensor<4xf32>) outs(%196 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%198 = tensor.empty() : tensor<4xf32>
%199 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%195, %197 : tensor<4xf32>, tensor<4xf32>) outs(%198 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%200 = tensor.empty() : tensor<4xf32>
%201 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%191, %193, %199 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%200 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%202 = tensor.empty() : tensor<4xi1>
%203 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%201 : tensor<4xf32>) outs(%202 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%204 = tensor.empty() : tensor<4xf32>
%205 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%201 : tensor<4xf32>) outs(%204 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%206 = tensor.empty() : tensor<4xf32>
%207 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %201 : tensor<4xf32>, tensor<4xf32>) outs(%206 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%208 = tensor.empty() : tensor<4xf32>
%209 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%201 : tensor<4xf32>) outs(%208 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%210 = tensor.empty() : tensor<4xf32>
%211 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%207, %209 : tensor<4xf32>, tensor<4xf32>) outs(%210 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%212 = tensor.empty() : tensor<4xf32>
%213 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%203, %205, %211 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%212 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%214 = tensor.empty() : tensor<4xi1>
%215 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%213 : tensor<4xf32>) outs(%214 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%216 = tensor.empty() : tensor<4xf32>
%217 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%213 : tensor<4xf32>) outs(%216 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%218 = tensor.empty() : tensor<4xf32>
%219 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %213 : tensor<4xf32>, tensor<4xf32>) outs(%218 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%220 = tensor.empty() : tensor<4xf32>
%221 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%213 : tensor<4xf32>) outs(%220 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%222 = tensor.empty() : tensor<4xf32>
%223 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%219, %221 : tensor<4xf32>, tensor<4xf32>) outs(%222 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%224 = tensor.empty() : tensor<4xf32>
%225 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%215, %217, %223 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%224 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%226 = tensor.empty() : tensor<4xi1>
%227 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%225 : tensor<4xf32>) outs(%226 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%228 = tensor.empty() : tensor<4xf32>
%229 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%225 : tensor<4xf32>) outs(%228 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%230 = tensor.empty() : tensor<4xf32>
%231 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %225 : tensor<4xf32>, tensor<4xf32>) outs(%230 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%232 = tensor.empty() : tensor<4xf32>
%233 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%225 : tensor<4xf32>) outs(%232 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%234 = tensor.empty() : tensor<4xf32>
%235 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%231, %233 : tensor<4xf32>, tensor<4xf32>) outs(%234 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%236 = tensor.empty() : tensor<4xf32>
%237 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%227, %229, %235 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%236 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
return %237 : tensor<4xf32>
}
}
@llandsmeer
Copy link
Copy Markdown
Author

//===------------------------------------------------------------*- C++ -*-===//
//
// Automatically generated file for High-level Synthesis (HLS).
//
//===----------------------------------------------------------------------===//
#include <hls_stream.h>
#include <hls_half.h>
#include <cassert>
#include <hls_math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

using namespace std;

// Copies four floats from the stream v0 into the array v1, in arrival order.
//   v0: input stream; read exactly four times.
//   v1: destination array; elements 0..3 are overwritten.
void node0(
  hls::stream<float> &v0,
  float v1[4]
) {
  loop0: for (int outer = 0; outer < 4; ++outer) {
    #pragma HLS pipeline II=1
    #pragma HLS loop_flatten
    // Single-trip inner loop kept so the loop1 label and the nest shape
    // emitted by the generator stay intact.
    loop1: for (int inner = 0; inner < 1; ++inner) {
      const int slot = outer + inner;
      v1[slot] = v0.read();
    }
  }
}

void node1(
  float v5[4],
  hls::stream<float> &v6,
  float v7,
  float v8
) {
  loop2: for (int v9 = 0; v9 < 4; v9++) {
    #pragma HLS pipeline II=1
    #pragma HLS loop_flatten
    loop3: for (int v10 = 0; v10 < 1; v10++) {
      float v11 = v5[(v9 + v10)];
      float v12 = v11 * v8;
      float v13 = v12 * v8;
      float v14 = v7 - v13;
      float v15 = v11 * v13;
      float v16 = v15 * v14;
      float v17 = -(v13);
      bool v18 = v13 > v8;
      float v19 = v18 ? (float)v17 : (float)v16;
      float v20 = v7 - v19;
      float v21 = v11 * v19;
      float v22 = v21 * v20;
      float v23 = -(v19);
      bool v24 = v19 > v8;
      float v25 = v24 ? (float)v23 : (float)v22;
      float v26 = v7 - v25;
      float v27 = v11 * v25;
      float v28 = v27 * v26;
      float v29 = -(v25);
      bool v30 = v25 > v8;
      float v31 = v30 ? (float)v29 : (float)v28;
      float v32 = v7 - v31;
      float v33 = v11 * v31;
      float v34 = v33 * v32;
      float v35 = -(v31);
      bool v36 = v31 > v8;
      float v37 = v36 ? (float)v35 : (float)v34;
      float v38 = v7 - v37;
      float v39 = v11 * v37;
      float v40 = v39 * v38;
      float v41 = -(v37);
      bool v42 = v37 > v8;
      float v43 = v42 ? (float)v41 : (float)v40;
      float v44 = v7 - v43;
      float v45 = v11 * v43;
      float v46 = v45 * v44;
      float v47 = -(v43);
      bool v48 = v43 > v8;
      float v49 = v48 ? (float)v47 : (float)v46;
      float v50 = v7 - v49;
      float v51 = v11 * v49;
      float v52 = v51 * v50;
      float v53 = -(v49);
      bool v54 = v49 > v8;
      float v55 = v54 ? (float)v53 : (float)v52;
      float v56 = v7 - v55;
      float v57 = v11 * v55;
      float v58 = v57 * v56;
      float v59 = -(v55);
      bool v60 = v55 > v8;
      float v61 = v60 ? (float)v59 : (float)v58;
      float v62 = v7 - v61;
      float v63 = v11 * v61;
      float v64 = v63 * v62;
      float v65 = -(v61);
      bool v66 = v61 > v8;
      float v67 = v66 ? (float)v65 : (float)v64;
      float v68 = v7 - v67;
      float v69 = v11 * v67;
      float v70 = v69 * v68;
      float v71 = -(v67);
      bool v72 = v67 > v8;
      float v73 = v72 ? (float)v71 : (float)v70;
      float v74 = v7 - v73;
      float v75 = v11 * v73;
      float v76 = v75 * v74;
      float v77 = -(v73);
      bool v78 = v73 > v8;
      float v79 = v78 ? (float)v77 : (float)v76;
      float v80 = v7 - v79;
      float v81 = v11 * v79;
      float v82 = v81 * v80;
      float v83 = -(v79);
      bool v84 = v79 > v8;
      float v85 = v84 ? (float)v83 : (float)v82;
      float v86 = v7 - v85;
      float v87 = v11 * v85;
      float v88 = v87 * v86;
      float v89 = -(v85);
      bool v90 = v85 > v8;
      float v91 = v90 ? (float)v89 : (float)v88;
      float v92 = v7 - v91;
      float v93 = v11 * v91;
      float v94 = v93 * v92;
      float v95 = -(v91);
      bool v96 = v91 > v8;
      float v97 = v96 ? (float)v95 : (float)v94;
      float v98 = v7 - v97;
      float v99 = v11 * v97;
      float v100 = v99 * v98;
      float v101 = -(v97);
      bool v102 = v97 > v8;
      float v103 = v102 ? (float)v101 : (float)v100;
      float v104 = v7 - v103;
      float v105 = v11 * v103;
      float v106 = v105 * v104;
      float v107 = -(v103);
      bool v108 = v103 > v8;
      float v109 = v108 ? (float)v107 : (float)v106;
      float v110 = v7 - v109;
      float v111 = v11 * v109;
      float v112 = v111 * v110;
      float v113 = -(v109);
      bool v114 = v109 > v8;
      float v115 = v114 ? (float)v113 : (float)v112;
      float v116 = v7 - v115;
      float v117 = v11 * v115;
      float v118 = v117 * v116;
      float v119 = -(v115);
      bool v120 = v115 > v8;
      float v121 = v120 ? (float)v119 : (float)v118;
      float v122 = v7 - v121;
      float v123 = v11 * v121;
      float v124 = v123 * v122;
      float v125 = -(v121);
      bool v126 = v121 > v8;
      float v127 = v126 ? (float)v125 : (float)v124;
      v6.write(v127);
    }
  }
  return ;
}

// Top-level function: node1 reads v128 and produces a stream of four floats,
// which node0 then stores into v129.
//   v128: input array handed to node1.
//   v129: output array filled by node0.
void forward(
  float v128[4],
  float v129[4]
) {
  #pragma HLS DATAFLOW
  // Stream connecting the producer (node1) to the consumer (node0).
  hls::stream<float> v130;
  #pragma HLS STREAM variable=v130 depth=4
  // node1 receives 1.0 as v7 and 0.5 as v8.
  node1(v128, v130, 1.0f, 0.5f);
  node0(v130, v129);
}

@SuhailB
Copy link
Copy Markdown

SuhailB commented May 14, 2025

Hi @llandsmeer,

Thank you for your interest and for pointing this out. I added support for arbitrary top-function names, while keeping forward as the default. However, if the kernel is named "main", it will conflict with the host code, which defines its own main function. So the top function in the input linalg.mlir can have any name except "main" for the flows to work.

// Identity indexing map: iteration index d0 addresses element d0.
#map = affine_map<(d0) -> (d0)>
module @jit_f {
  // Top function renamed from main to kernel.
  func.func public @kernel(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    %one = arith.constant 1.000000e+00 : f32
    %init = tensor.empty() : tensor<4xf32>
    // Elementwise over the 4 entries: result[i] = %arg0[i] + 1.0
    %sum = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<4xf32>) outs(%init : tensor<4xf32>) {
    ^bb0(%elem: f32, %acc: f32):
      %inc = arith.addf %elem, %one : f32
      linalg.yield %inc : f32
    } -> tensor<4xf32>
    return %sum : tensor<4xf32>
  }
}

run.sh

# Name of the top-level function in jax.linalg.mlir.
# It must not be "main": that name conflicts with the host code's main function.
kernel_name=kernel

# NOTE: comments must not follow a line-continuation backslash. Outside quotes
# "\ # ..." escapes the space and ends the command early; inside double quotes
# the text becomes part of the option string. Annotations therefore live on
# their own lines.

# Host pipeline (updated: top-func is passed explicitly).
streamhls-opt ./jax.linalg.mlir \
    -streamhls-host-pipeline="top-func=$kernel_name" \
    > host.mlir

# Emit the host testbench.
streamhls-translate ./host.mlir \
    -emit-vivado-hls \
    -vitis-hls-weights-dir=data \
    -vitis-hls-is-host=true \
    -o host_tb.cpp


# Kernel pipeline (updated: top-func is passed explicitly).
streamhls-opt jax.linalg.mlir \
  -streamhls-kernel-pipeline="top-func=$kernel_name \
    graph-file=graph\
    report-file=report\
    optimize-schedule=1\
    parallelize-nodes=1\
    combined-optimization=0\
    board-dsps=1024 \
    tiling-limit=16 \
    time-limit-minutes=10 \
    bufferize-func-args=0 \
    optimize-conv-reuse=0 \
    minimize-on-chip-buffers=0 \
    debug-point=14" > kernel.mlir

# Emit the kernel (updated: the top function name is passed to the translator).
streamhls-translate \
    kernel.mlir \
    -emit-vivado-hls \
    -vitis-hls-top-func="$kernel_name" \
    -o kernel.cpp

expected output:

Permutation DesignSpaceSize: 1
Parallelization DesignSpaceSize: 9
Total DesignSpaceSize: 9
Permutation solver: latency: 3
Permutation DesignSpaceSize: 1
Parallelization DesignSpaceSize: 9
Total DesignSpaceSize: 9
Parallelization solver: Parallel Latency: 3
Total DSPs: 2

Please let me know if you face any other issues.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment