@llandsmeer
Last active May 14, 2025 16:00

StreamHLS with JAX
set -ex
. ../env/bin/activate
python3 jax_test.py
iree-opt ./jax.stablehlo.mlir \
--iree-stablehlo-input-transformation-pipeline \
--convert-scf-to-cf \
> jax.linalg.mlir
streamhls-opt ./jax.linalg.mlir \
-streamhls-host-pipeline \
> host.mlir
streamhls-translate ./host.mlir \
-emit-vivado-hls \
-vitis-hls-weights-dir=data \
-vitis-hls-is-host=true \
-o host_tb.cpp
# put top-func=main
streamhls-opt jax.linalg.mlir \
  -streamhls-kernel-pipeline="top-func=main \
    graph-file=graph \
    report-file=report \
    optimize-schedule=1 \
    parallelize-nodes=1 \
    combined-optimization=0 \
    board-dsps=1024 \
    tiling-limit=16 \
    time-limit-minutes=10 \
    bufferize-func-args=0 \
    optimize-conv-reuse=0 \
    minimize-on-chip-buffers=0 \
    debug-point=14" > kernel.mlir
streamhls-translate \
kernel.mlir \
-emit-vivado-hls \
-o kernel.cpp
from jax import numpy as jnp
from jax import jit
import jax
def f(rs):
    o = jax.vmap(lambda r:
        jax.lax.scan(lambda x, _:
            (jax.lax.select(
                x > 0.5,
                -x,
                r*x*(1-x)),
             x),
            0.5, length=20, unroll=True)[0]
        )(rs)
    return o
    # return jax.lax.select(x > 3, x - 5, x + 5)
    # return x + 1

x = jnp.ones((4,))
mlir = jax.jit(f).lower(x).as_text()
with open('jax.stablehlo.mlir', 'w') as f:
    print(mlir.replace('@main(', '@forward('), file=f)
print(mlir)
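
The mapped scan above simply iterates x ← −x when x > 0.5 and x ← r·x·(1−x) otherwise, 20 times from x = 0.5, for each element of rs. A plain-NumPy sanity check of what f computes could look like this (a sketch; numerics may differ slightly from the float32 JAX result):

import numpy as np

def f_reference(rs, steps=20):
    # re-implementation of the scanned body above, for checking the JAX/HLS output
    out = []
    for r in rs:
        x = 0.5
        for _ in range(steps):
            x = -x if x > 0.5 else r * x * (1 - x)
        out.append(x)
    return np.array(out, dtype=np.float32)

# should agree (up to rounding) with jax.jit(f)(jnp.ones((4,)))
print(f_reference(np.ones(4, dtype=np.float32)))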
#map = affine_map<(d0) -> (d0)>
module @jit_f {
func.func public @forward(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%cst = arith.constant -5.000000e-01 : f32
%false = arith.constant false
%cst_0 = arith.constant 1.000000e+00 : f32
%cst_1 = arith.constant 5.000000e-01 : f32
%0 = tensor.empty() : tensor<4xi1>
%1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<4xi1>) {
^bb0(%out: i1):
linalg.yield %false : i1
} -> tensor<4xi1>
%2 = tensor.empty() : tensor<4xf32>
%3 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%2 : tensor<4xf32>) {
^bb0(%out: f32):
linalg.yield %cst : f32
} -> tensor<4xf32>
%4 = tensor.empty() : tensor<4xf32>
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<4xf32>) outs(%4 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.mulf %in, %cst_1 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%6 = tensor.empty() : tensor<4xf32>
%7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%5 : tensor<4xf32>) outs(%6 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.mulf %in, %cst_1 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%8 = tensor.empty() : tensor<4xf32>
%9 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1, %3, %7 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%8 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%10 = tensor.empty() : tensor<4xi1>
%11 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%9 : tensor<4xf32>) outs(%10 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%12 = tensor.empty() : tensor<4xf32>
%13 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%9 : tensor<4xf32>) outs(%12 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%14 = tensor.empty() : tensor<4xf32>
%15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %9 : tensor<4xf32>, tensor<4xf32>) outs(%14 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%16 = tensor.empty() : tensor<4xf32>
%17 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%9 : tensor<4xf32>) outs(%16 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%18 = tensor.empty() : tensor<4xf32>
%19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%15, %17 : tensor<4xf32>, tensor<4xf32>) outs(%18 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%20 = tensor.empty() : tensor<4xf32>
%21 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%11, %13, %19 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%20 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%22 = tensor.empty() : tensor<4xi1>
%23 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%21 : tensor<4xf32>) outs(%22 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%24 = tensor.empty() : tensor<4xf32>
%25 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%21 : tensor<4xf32>) outs(%24 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%26 = tensor.empty() : tensor<4xf32>
%27 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %21 : tensor<4xf32>, tensor<4xf32>) outs(%26 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%28 = tensor.empty() : tensor<4xf32>
%29 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%21 : tensor<4xf32>) outs(%28 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%30 = tensor.empty() : tensor<4xf32>
%31 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%27, %29 : tensor<4xf32>, tensor<4xf32>) outs(%30 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%32 = tensor.empty() : tensor<4xf32>
%33 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%23, %25, %31 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%32 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%34 = tensor.empty() : tensor<4xi1>
%35 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%33 : tensor<4xf32>) outs(%34 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%36 = tensor.empty() : tensor<4xf32>
%37 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%33 : tensor<4xf32>) outs(%36 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%38 = tensor.empty() : tensor<4xf32>
%39 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %33 : tensor<4xf32>, tensor<4xf32>) outs(%38 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%40 = tensor.empty() : tensor<4xf32>
%41 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%33 : tensor<4xf32>) outs(%40 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%42 = tensor.empty() : tensor<4xf32>
%43 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%39, %41 : tensor<4xf32>, tensor<4xf32>) outs(%42 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%44 = tensor.empty() : tensor<4xf32>
%45 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%35, %37, %43 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%44 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%46 = tensor.empty() : tensor<4xi1>
%47 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%45 : tensor<4xf32>) outs(%46 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%48 = tensor.empty() : tensor<4xf32>
%49 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%45 : tensor<4xf32>) outs(%48 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%50 = tensor.empty() : tensor<4xf32>
%51 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %45 : tensor<4xf32>, tensor<4xf32>) outs(%50 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%52 = tensor.empty() : tensor<4xf32>
%53 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%45 : tensor<4xf32>) outs(%52 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%54 = tensor.empty() : tensor<4xf32>
%55 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%51, %53 : tensor<4xf32>, tensor<4xf32>) outs(%54 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%56 = tensor.empty() : tensor<4xf32>
%57 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%47, %49, %55 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%56 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%58 = tensor.empty() : tensor<4xi1>
%59 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%57 : tensor<4xf32>) outs(%58 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%60 = tensor.empty() : tensor<4xf32>
%61 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%57 : tensor<4xf32>) outs(%60 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%62 = tensor.empty() : tensor<4xf32>
%63 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %57 : tensor<4xf32>, tensor<4xf32>) outs(%62 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%64 = tensor.empty() : tensor<4xf32>
%65 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%57 : tensor<4xf32>) outs(%64 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%66 = tensor.empty() : tensor<4xf32>
%67 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%63, %65 : tensor<4xf32>, tensor<4xf32>) outs(%66 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%68 = tensor.empty() : tensor<4xf32>
%69 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%59, %61, %67 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%68 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%70 = tensor.empty() : tensor<4xi1>
%71 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%69 : tensor<4xf32>) outs(%70 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%72 = tensor.empty() : tensor<4xf32>
%73 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%69 : tensor<4xf32>) outs(%72 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%74 = tensor.empty() : tensor<4xf32>
%75 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %69 : tensor<4xf32>, tensor<4xf32>) outs(%74 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%76 = tensor.empty() : tensor<4xf32>
%77 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%69 : tensor<4xf32>) outs(%76 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%78 = tensor.empty() : tensor<4xf32>
%79 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%75, %77 : tensor<4xf32>, tensor<4xf32>) outs(%78 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%80 = tensor.empty() : tensor<4xf32>
%81 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%71, %73, %79 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%80 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%82 = tensor.empty() : tensor<4xi1>
%83 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%81 : tensor<4xf32>) outs(%82 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%84 = tensor.empty() : tensor<4xf32>
%85 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%81 : tensor<4xf32>) outs(%84 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%86 = tensor.empty() : tensor<4xf32>
%87 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %81 : tensor<4xf32>, tensor<4xf32>) outs(%86 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%88 = tensor.empty() : tensor<4xf32>
%89 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%81 : tensor<4xf32>) outs(%88 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%90 = tensor.empty() : tensor<4xf32>
%91 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%87, %89 : tensor<4xf32>, tensor<4xf32>) outs(%90 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%92 = tensor.empty() : tensor<4xf32>
%93 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%83, %85, %91 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%92 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%94 = tensor.empty() : tensor<4xi1>
%95 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%93 : tensor<4xf32>) outs(%94 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%96 = tensor.empty() : tensor<4xf32>
%97 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%93 : tensor<4xf32>) outs(%96 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%98 = tensor.empty() : tensor<4xf32>
%99 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %93 : tensor<4xf32>, tensor<4xf32>) outs(%98 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%100 = tensor.empty() : tensor<4xf32>
%101 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%93 : tensor<4xf32>) outs(%100 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%102 = tensor.empty() : tensor<4xf32>
%103 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%99, %101 : tensor<4xf32>, tensor<4xf32>) outs(%102 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%104 = tensor.empty() : tensor<4xf32>
%105 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%95, %97, %103 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%104 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%106 = tensor.empty() : tensor<4xi1>
%107 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%105 : tensor<4xf32>) outs(%106 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%108 = tensor.empty() : tensor<4xf32>
%109 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%105 : tensor<4xf32>) outs(%108 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%110 = tensor.empty() : tensor<4xf32>
%111 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %105 : tensor<4xf32>, tensor<4xf32>) outs(%110 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%112 = tensor.empty() : tensor<4xf32>
%113 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%105 : tensor<4xf32>) outs(%112 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%114 = tensor.empty() : tensor<4xf32>
%115 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%111, %113 : tensor<4xf32>, tensor<4xf32>) outs(%114 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%116 = tensor.empty() : tensor<4xf32>
%117 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%107, %109, %115 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%116 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%118 = tensor.empty() : tensor<4xi1>
%119 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%117 : tensor<4xf32>) outs(%118 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%120 = tensor.empty() : tensor<4xf32>
%121 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%117 : tensor<4xf32>) outs(%120 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%122 = tensor.empty() : tensor<4xf32>
%123 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %117 : tensor<4xf32>, tensor<4xf32>) outs(%122 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%124 = tensor.empty() : tensor<4xf32>
%125 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%117 : tensor<4xf32>) outs(%124 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%126 = tensor.empty() : tensor<4xf32>
%127 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%123, %125 : tensor<4xf32>, tensor<4xf32>) outs(%126 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%128 = tensor.empty() : tensor<4xf32>
%129 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%119, %121, %127 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%128 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%130 = tensor.empty() : tensor<4xi1>
%131 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%129 : tensor<4xf32>) outs(%130 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%132 = tensor.empty() : tensor<4xf32>
%133 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%129 : tensor<4xf32>) outs(%132 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%134 = tensor.empty() : tensor<4xf32>
%135 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %129 : tensor<4xf32>, tensor<4xf32>) outs(%134 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%136 = tensor.empty() : tensor<4xf32>
%137 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%129 : tensor<4xf32>) outs(%136 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%138 = tensor.empty() : tensor<4xf32>
%139 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%135, %137 : tensor<4xf32>, tensor<4xf32>) outs(%138 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%140 = tensor.empty() : tensor<4xf32>
%141 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%131, %133, %139 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%140 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%142 = tensor.empty() : tensor<4xi1>
%143 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%141 : tensor<4xf32>) outs(%142 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%144 = tensor.empty() : tensor<4xf32>
%145 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%141 : tensor<4xf32>) outs(%144 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%146 = tensor.empty() : tensor<4xf32>
%147 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %141 : tensor<4xf32>, tensor<4xf32>) outs(%146 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%148 = tensor.empty() : tensor<4xf32>
%149 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%141 : tensor<4xf32>) outs(%148 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%150 = tensor.empty() : tensor<4xf32>
%151 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%147, %149 : tensor<4xf32>, tensor<4xf32>) outs(%150 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%152 = tensor.empty() : tensor<4xf32>
%153 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%143, %145, %151 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%152 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%154 = tensor.empty() : tensor<4xi1>
%155 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%153 : tensor<4xf32>) outs(%154 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%156 = tensor.empty() : tensor<4xf32>
%157 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%153 : tensor<4xf32>) outs(%156 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%158 = tensor.empty() : tensor<4xf32>
%159 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %153 : tensor<4xf32>, tensor<4xf32>) outs(%158 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%160 = tensor.empty() : tensor<4xf32>
%161 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%153 : tensor<4xf32>) outs(%160 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%162 = tensor.empty() : tensor<4xf32>
%163 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%159, %161 : tensor<4xf32>, tensor<4xf32>) outs(%162 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%164 = tensor.empty() : tensor<4xf32>
%165 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%155, %157, %163 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%164 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%166 = tensor.empty() : tensor<4xi1>
%167 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%165 : tensor<4xf32>) outs(%166 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%168 = tensor.empty() : tensor<4xf32>
%169 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%165 : tensor<4xf32>) outs(%168 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%170 = tensor.empty() : tensor<4xf32>
%171 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %165 : tensor<4xf32>, tensor<4xf32>) outs(%170 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%172 = tensor.empty() : tensor<4xf32>
%173 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%165 : tensor<4xf32>) outs(%172 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%174 = tensor.empty() : tensor<4xf32>
%175 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%171, %173 : tensor<4xf32>, tensor<4xf32>) outs(%174 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%176 = tensor.empty() : tensor<4xf32>
%177 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%167, %169, %175 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%176 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%178 = tensor.empty() : tensor<4xi1>
%179 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%177 : tensor<4xf32>) outs(%178 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%180 = tensor.empty() : tensor<4xf32>
%181 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%177 : tensor<4xf32>) outs(%180 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%182 = tensor.empty() : tensor<4xf32>
%183 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %177 : tensor<4xf32>, tensor<4xf32>) outs(%182 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%184 = tensor.empty() : tensor<4xf32>
%185 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%177 : tensor<4xf32>) outs(%184 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%186 = tensor.empty() : tensor<4xf32>
%187 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%183, %185 : tensor<4xf32>, tensor<4xf32>) outs(%186 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%188 = tensor.empty() : tensor<4xf32>
%189 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%179, %181, %187 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%188 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%190 = tensor.empty() : tensor<4xi1>
%191 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%189 : tensor<4xf32>) outs(%190 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%192 = tensor.empty() : tensor<4xf32>
%193 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%189 : tensor<4xf32>) outs(%192 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%194 = tensor.empty() : tensor<4xf32>
%195 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %189 : tensor<4xf32>, tensor<4xf32>) outs(%194 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%196 = tensor.empty() : tensor<4xf32>
%197 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%189 : tensor<4xf32>) outs(%196 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%198 = tensor.empty() : tensor<4xf32>
%199 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%195, %197 : tensor<4xf32>, tensor<4xf32>) outs(%198 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%200 = tensor.empty() : tensor<4xf32>
%201 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%191, %193, %199 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%200 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%202 = tensor.empty() : tensor<4xi1>
%203 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%201 : tensor<4xf32>) outs(%202 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%204 = tensor.empty() : tensor<4xf32>
%205 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%201 : tensor<4xf32>) outs(%204 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%206 = tensor.empty() : tensor<4xf32>
%207 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %201 : tensor<4xf32>, tensor<4xf32>) outs(%206 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%208 = tensor.empty() : tensor<4xf32>
%209 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%201 : tensor<4xf32>) outs(%208 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%210 = tensor.empty() : tensor<4xf32>
%211 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%207, %209 : tensor<4xf32>, tensor<4xf32>) outs(%210 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%212 = tensor.empty() : tensor<4xf32>
%213 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%203, %205, %211 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%212 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%214 = tensor.empty() : tensor<4xi1>
%215 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%213 : tensor<4xf32>) outs(%214 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%216 = tensor.empty() : tensor<4xf32>
%217 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%213 : tensor<4xf32>) outs(%216 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%218 = tensor.empty() : tensor<4xf32>
%219 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %213 : tensor<4xf32>, tensor<4xf32>) outs(%218 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%220 = tensor.empty() : tensor<4xf32>
%221 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%213 : tensor<4xf32>) outs(%220 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%222 = tensor.empty() : tensor<4xf32>
%223 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%219, %221 : tensor<4xf32>, tensor<4xf32>) outs(%222 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%224 = tensor.empty() : tensor<4xf32>
%225 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%215, %217, %223 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%224 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%226 = tensor.empty() : tensor<4xi1>
%227 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%225 : tensor<4xf32>) outs(%226 : tensor<4xi1>) {
^bb0(%in: f32, %out: i1):
%238 = arith.cmpf ogt, %in, %cst_1 : f32
linalg.yield %238 : i1
} -> tensor<4xi1>
%228 = tensor.empty() : tensor<4xf32>
%229 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%225 : tensor<4xf32>) outs(%228 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.negf %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%230 = tensor.empty() : tensor<4xf32>
%231 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %225 : tensor<4xf32>, tensor<4xf32>) outs(%230 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%232 = tensor.empty() : tensor<4xf32>
%233 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%225 : tensor<4xf32>) outs(%232 : tensor<4xf32>) {
^bb0(%in: f32, %out: f32):
%238 = arith.subf %cst_0, %in : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%234 = tensor.empty() : tensor<4xf32>
%235 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%231, %233 : tensor<4xf32>, tensor<4xf32>) outs(%234 : tensor<4xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
%238 = arith.mulf %in, %in_2 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
%236 = tensor.empty() : tensor<4xf32>
%237 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%227, %229, %235 : tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) outs(%236 : tensor<4xf32>) {
^bb0(%in: i1, %in_2: f32, %in_3: f32, %out: f32):
%238 = arith.select %in, %in_2, %in_3 : f32
linalg.yield %238 : f32
} -> tensor<4xf32>
return %237 : tensor<4xf32>
}
}
@llandsmeer (Author) commented:

//===------------------------------------------------------------*- C++ -*-===//
//
// Automatically generated file for High-level Synthesis (HLS).
//
//===----------------------------------------------------------------------===//
#include <hls_stream.h>
#include <hls_half.h>
#include <cassert>
#include <hls_math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

using namespace std;

void node0(
  hls::stream<float> &v0,
  float v1[4]
) {
  loop0: for (int v2 = 0; v2 < 4; v2++) {
    #pragma HLS pipeline II=1
    #pragma HLS loop_flatten
    loop1: for (int v3 = 0; v3 < 1; v3++) {
      float v4 = v0.read();
      v1[(v2 + v3)] = v4;
    }
  }
  return ;
}

void node1(
  float v5[4],
  hls::stream<float> &v6,
  float v7,
  float v8
) {
  loop2: for (int v9 = 0; v9 < 4; v9++) {
    #pragma HLS pipeline II=1
    #pragma HLS loop_flatten
    loop3: for (int v10 = 0; v10 < 1; v10++) {
      float v11 = v5[(v9 + v10)];
      float v12 = v11 * v8;
      float v13 = v12 * v8;
      float v14 = v7 - v13;
      float v15 = v11 * v13;
      float v16 = v15 * v14;
      float v17 = -(v13);
      bool v18 = v13 > v8;
      float v19 = v18 ? (float)v17 : (float)v16;
      float v20 = v7 - v19;
      float v21 = v11 * v19;
      float v22 = v21 * v20;
      float v23 = -(v19);
      bool v24 = v19 > v8;
      float v25 = v24 ? (float)v23 : (float)v22;
      float v26 = v7 - v25;
      float v27 = v11 * v25;
      float v28 = v27 * v26;
      float v29 = -(v25);
      bool v30 = v25 > v8;
      float v31 = v30 ? (float)v29 : (float)v28;
      float v32 = v7 - v31;
      float v33 = v11 * v31;
      float v34 = v33 * v32;
      float v35 = -(v31);
      bool v36 = v31 > v8;
      float v37 = v36 ? (float)v35 : (float)v34;
      float v38 = v7 - v37;
      float v39 = v11 * v37;
      float v40 = v39 * v38;
      float v41 = -(v37);
      bool v42 = v37 > v8;
      float v43 = v42 ? (float)v41 : (float)v40;
      float v44 = v7 - v43;
      float v45 = v11 * v43;
      float v46 = v45 * v44;
      float v47 = -(v43);
      bool v48 = v43 > v8;
      float v49 = v48 ? (float)v47 : (float)v46;
      float v50 = v7 - v49;
      float v51 = v11 * v49;
      float v52 = v51 * v50;
      float v53 = -(v49);
      bool v54 = v49 > v8;
      float v55 = v54 ? (float)v53 : (float)v52;
      float v56 = v7 - v55;
      float v57 = v11 * v55;
      float v58 = v57 * v56;
      float v59 = -(v55);
      bool v60 = v55 > v8;
      float v61 = v60 ? (float)v59 : (float)v58;
      float v62 = v7 - v61;
      float v63 = v11 * v61;
      float v64 = v63 * v62;
      float v65 = -(v61);
      bool v66 = v61 > v8;
      float v67 = v66 ? (float)v65 : (float)v64;
      float v68 = v7 - v67;
      float v69 = v11 * v67;
      float v70 = v69 * v68;
      float v71 = -(v67);
      bool v72 = v67 > v8;
      float v73 = v72 ? (float)v71 : (float)v70;
      float v74 = v7 - v73;
      float v75 = v11 * v73;
      float v76 = v75 * v74;
      float v77 = -(v73);
      bool v78 = v73 > v8;
      float v79 = v78 ? (float)v77 : (float)v76;
      float v80 = v7 - v79;
      float v81 = v11 * v79;
      float v82 = v81 * v80;
      float v83 = -(v79);
      bool v84 = v79 > v8;
      float v85 = v84 ? (float)v83 : (float)v82;
      float v86 = v7 - v85;
      float v87 = v11 * v85;
      float v88 = v87 * v86;
      float v89 = -(v85);
      bool v90 = v85 > v8;
      float v91 = v90 ? (float)v89 : (float)v88;
      float v92 = v7 - v91;
      float v93 = v11 * v91;
      float v94 = v93 * v92;
      float v95 = -(v91);
      bool v96 = v91 > v8;
      float v97 = v96 ? (float)v95 : (float)v94;
      float v98 = v7 - v97;
      float v99 = v11 * v97;
      float v100 = v99 * v98;
      float v101 = -(v97);
      bool v102 = v97 > v8;
      float v103 = v102 ? (float)v101 : (float)v100;
      float v104 = v7 - v103;
      float v105 = v11 * v103;
      float v106 = v105 * v104;
      float v107 = -(v103);
      bool v108 = v103 > v8;
      float v109 = v108 ? (float)v107 : (float)v106;
      float v110 = v7 - v109;
      float v111 = v11 * v109;
      float v112 = v111 * v110;
      float v113 = -(v109);
      bool v114 = v109 > v8;
      float v115 = v114 ? (float)v113 : (float)v112;
      float v116 = v7 - v115;
      float v117 = v11 * v115;
      float v118 = v117 * v116;
      float v119 = -(v115);
      bool v120 = v115 > v8;
      float v121 = v120 ? (float)v119 : (float)v118;
      float v122 = v7 - v121;
      float v123 = v11 * v121;
      float v124 = v123 * v122;
      float v125 = -(v121);
      bool v126 = v121 > v8;
      float v127 = v126 ? (float)v125 : (float)v124;
      v6.write(v127);
    }
  }
  return ;
}

void forward(
  float v128[4],
  float v129[4]
) {
	#pragma HLS DATAFLOW
  hls::stream<float> v130;
	#pragma HLS STREAM variable=v130 depth=4
  node1(v128, v130, 1.000000, 0.500000);
  node0(v130, v129);
  return ;
}

@SuhailB commented May 14, 2025

Hi @llandsmeer,

Thank you for your interest and for pointing this out. I have added support for arbitrary top-function names, keeping forward as the default. However, if the kernel is named "main", it conflicts with the main function in the generated host code. So the top function in the input linalg.mlir can have any name except "main" for the flows to work.

#map = affine_map<(d0) -> (d0)>
module @jit_f {
  // renamed main -> kernel
  func.func public @kernel(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    %cst = arith.constant 1.000000e+00 : f32
    %0 = tensor.empty() : tensor<4xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<4xf32>) outs(%0 : tensor<4xf32>) {
    ^bb0(%in: f32, %out: f32):
      %2 = arith.addf %in, %cst : f32
      linalg.yield %2 : f32
    } -> tensor<4xf32>
    return %1 : tensor<4xf32>
  }
}
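
On the JAX side, the rename to a non-"main" entry point can be done the same way the original jax_test.py swapped @main for @forward, for example (a minimal sketch, assuming the same f and x as in jax_test.py and the kernel name used above):

mlir = jax.jit(f).lower(x).as_text()
with open('jax.stablehlo.mlir', 'w') as out:
    # export the entry point as @kernel instead of @main
    print(mlir.replace('@main(', '@kernel('), file=out)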

run.sh

kernel_name=kernel

# updated: pass the kernel name to the host pipeline
streamhls-opt ./jax.linalg.mlir \
    -streamhls-host-pipeline="top-func=$kernel_name" \
    > host.mlir

streamhls-translate ./host.mlir \
    -emit-vivado-hls \
    -vitis-hls-weights-dir=data \
    -vitis-hls-is-host=true \
    -o host_tb.cpp


# updated: pass the kernel name to the kernel pipeline
streamhls-opt jax.linalg.mlir \
  -streamhls-kernel-pipeline="top-func=$kernel_name \
    graph-file=graph\
    report-file=report\
    optimize-schedule=1\
    parallelize-nodes=1\
    combined-optimization=0\
    board-dsps=1024 \
    tiling-limit=16 \
    time-limit-minutes=10 \
    bufferize-func-args=0 \
    optimize-conv-reuse=0 \
    minimize-on-chip-buffers=0 \
    debug-point=14" > kernel.mlir

# updated: pass the kernel name to the translator
streamhls-translate \
    kernel.mlir \
    -emit-vivado-hls \
    -vitis-hls-top-func=$kernel_name \
    -o kernel.cpp

expected output:

Permutation DesignSpaceSize: 1
Parallelization DesignSpaceSize: 9
Total DesignSpaceSize: 9
Permutation solver: latency: 3
Permutation DesignSpaceSize: 1
Parallelization DesignSpaceSize: 9
Total DesignSpaceSize: 9
Parallelization solver: Parallel Latency: 3
Total DSPs: 2

Please let me know if you face any other issues.
