// Input reproducer: @predict maps tensor<?x784xf32> to tensor<?x10xf32> using
// three dot + bias-add stages, each followed by tanh(x * 0.5) * 0.5 + 0.5, and a
// final max-subtract / exp / sum / divide sequence over dimension 1.
func @predict(%arg0: tensor<?x784xf32>) -> tensor<?x10xf32> attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<?x784>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
// Addresses of the weight and bias variables used by the three stages.
%0 = flow.variable.address @__iree_flow___sm_node1__h1_weights : !iree.ptr<tensor<784x256xf32>>
%1 = flow.variable.address @__iree_flow___sm_node4__h1_bias : !iree.ptr<tensor<256xf32>>
%2 = flow.variable.address @__iree_flow___sm_node2__h2_weights : !iree.ptr<tensor<256x256xf32>>
%3 = flow.variable.address @__iree_flow___sm_node5__h2_bias : !iree.ptr<tensor<256xf32>>
%4 = flow.variable.address @__iree_flow___sm_node3__out_weights : !iree.ptr<tensor<256x10xf32>>
%5 = flow.variable.address @__iree_flow___sm_node6__out_bias : !iree.ptr<tensor<10xf32>>
// Static ranked shapes used as the lhs operand of the bias broadcast-shape ops below.
%rs256 = shapex.const_ranked_shape : !shapex.ranked_shape<[256]>
%rs10 = shapex.const_ranked_shape : !shapex.ranked_shape<[10]>
// Scalar constants: 0.5 for the tanh-based activation, 0xFF800000 (the f32 bit
// pattern of -infinity) as the init value of the max reduction, and 0.0 as the
// init value of the sum reduction.
%6 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
%7 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
%8 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// Load the variables. The load order differs from the address order above:
// %9 = h2_bias, %10 = out_bias, %11 = h1_bias, %12 = h2_weights,
// %13 = out_weights, %14 = h1_weights.
%9 = flow.variable.load.indirect %3 : !iree.ptr<tensor<256xf32>> -> tensor<256xf32>
%10 = flow.variable.load.indirect %5 : !iree.ptr<tensor<10xf32>> -> tensor<10xf32>
%11 = flow.variable.load.indirect %1 : !iree.ptr<tensor<256xf32>> -> tensor<256xf32>
%12 = flow.variable.load.indirect %2 : !iree.ptr<tensor<256x256xf32>> -> tensor<256x256xf32>
%13 = flow.variable.load.indirect %4 : !iree.ptr<tensor<256x10xf32>> -> tensor<256x10xf32>
%14 = flow.variable.load.indirect %0 : !iree.ptr<tensor<784x256xf32>> -> tensor<784x256xf32>
// Stage 1: dot with h1_weights, then broadcast both operands to a common
// [?,256] shape and add h1_bias (%11).
%15 = "xla_hlo.dot"(%arg0, %14) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
%16 = shapex.get_ranked_shape %15 : tensor<?x256xf32> -> !shapex.ranked_shape<[?,256]>
%17 = "shapex.ranked_broadcast_shape"(%rs256, %16) {lhs_broadcast_dimensions = dense<1> : tensor<1xi64>, rhs_broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (!shapex.ranked_shape<[256]>, !shapex.ranked_shape<[?,256]>) -> !shapex.ranked_shape<[?,256]>
%18 = "shapex.ranked_broadcast_in_dim"(%15, %17) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%19 = "shapex.ranked_broadcast_in_dim"(%11, %17) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%20 = xla_hlo.add %18, %19 : tensor<?x256xf32>
// Activation: tanh(%20 * 0.5) * 0.5 + 0.5, reusing the broadcast 0.5 tensor %22
// for all three uses.
%21 = shapex.get_ranked_shape %20 : tensor<?x256xf32> -> !shapex.ranked_shape<[?,256]>
%22 = "shapex.ranked_broadcast_in_dim"(%6, %21) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%23 = xla_hlo.multiply %20, %22 : tensor<?x256xf32>
%24 = "xla_hlo.tanh"(%23) : (tensor<?x256xf32>) -> tensor<?x256xf32>
%25 = xla_hlo.multiply %24, %22 : tensor<?x256xf32>
%26 = xla_hlo.add %25, %22 : tensor<?x256xf32>
// Stage 2: dot with h2_weights (%12), add h2_bias (%9), then the same activation.
%27 = "xla_hlo.dot"(%26, %12) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
%28 = shapex.get_ranked_shape %27 : tensor<?x256xf32> -> !shapex.ranked_shape<[?,256]>
%29 = "shapex.ranked_broadcast_shape"(%rs256, %28) {lhs_broadcast_dimensions = dense<1> : tensor<1xi64>, rhs_broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (!shapex.ranked_shape<[256]>, !shapex.ranked_shape<[?,256]>) -> !shapex.ranked_shape<[?,256]>
%30 = "shapex.ranked_broadcast_in_dim"(%27, %29) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%31 = "shapex.ranked_broadcast_in_dim"(%9, %29) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%32 = xla_hlo.add %30, %31 : tensor<?x256xf32>
%33 = shapex.get_ranked_shape %32 : tensor<?x256xf32> -> !shapex.ranked_shape<[?,256]>
%34 = "shapex.ranked_broadcast_in_dim"(%6, %33) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%35 = xla_hlo.multiply %32, %34 : tensor<?x256xf32>
%36 = "xla_hlo.tanh"(%35) : (tensor<?x256xf32>) -> tensor<?x256xf32>
%37 = xla_hlo.multiply %36, %34 : tensor<?x256xf32>
%38 = xla_hlo.add %37, %34 : tensor<?x256xf32>
// Stage 3: dot with out_weights (%13), add out_bias (%10), then the same
// activation, now on [?,10] tensors.
%39 = "xla_hlo.dot"(%38, %13) : (tensor<?x256xf32>, tensor<256x10xf32>) -> tensor<?x10xf32>
%40 = shapex.get_ranked_shape %39 : tensor<?x10xf32> -> !shapex.ranked_shape<[?,10]>
%41 = "shapex.ranked_broadcast_shape"(%rs10, %40) {lhs_broadcast_dimensions = dense<1> : tensor<1xi64>, rhs_broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (!shapex.ranked_shape<[10]>, !shapex.ranked_shape<[?,10]>) -> !shapex.ranked_shape<[?,10]>
%42 = "shapex.ranked_broadcast_in_dim"(%39, %41) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%43 = "shapex.ranked_broadcast_in_dim"(%10, %41) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%44 = xla_hlo.add %42, %43 : tensor<?x10xf32>
%45 = shapex.get_ranked_shape %44 : tensor<?x10xf32> -> !shapex.ranked_shape<[?,10]>
%46 = "shapex.ranked_broadcast_in_dim"(%6, %45) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%47 = xla_hlo.multiply %44, %46 : tensor<?x10xf32>
%48 = "xla_hlo.tanh"(%47) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%49 = xla_hlo.multiply %48, %46 : tensor<?x10xf32>
%50 = xla_hlo.add %49, %46 : tensor<?x10xf32>
// Maximum of %50 over dimension 1, starting from the -infinity constant %7.
%51 = "xla_hlo.reduce"(%50, %7) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
%60 = xla_hlo.maximum %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%60) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
// Broadcast the per-row maximum back to [?,10], subtract it and exponentiate.
%52 = shapex.get_ranked_shape %50 : tensor<?x10xf32> -> !shapex.ranked_shape<[?,10]>
%53 = "shapex.ranked_broadcast_in_dim"(%51, %52) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%54 = xla_hlo.subtract %50, %53 : tensor<?x10xf32>
%55 = "xla_hlo.exponential"(%54) : (tensor<?x10xf32>) -> tensor<?x10xf32>
// Sum of the exponentials over dimension 1, starting from the 0.0 constant %8.
%56 = "xla_hlo.reduce"(%55, %8) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
%60 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%60) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
// %57 queries the shape of %50 again (same operand as %52); broadcast the
// per-row sums to [?,10] and divide to produce the result.
%57 = shapex.get_ranked_shape %50 : tensor<?x10xf32> -> !shapex.ranked_shape<[?,10]>
%58 = "shapex.ranked_broadcast_in_dim"(%56, %57) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%59 = xla_hlo.divide %55, %58 : tensor<?x10xf32>
return %59 : tensor<?x10xf32>
}Compiles to vulkan-spirv vm module.
iree-opt -mlir-disable-threading -iree-hal-target-backends=vulkan-spirv -pass-pipeline='canonicalize, func(xla-legalize-control-flow, iree-flow-hlo-to-hlo-preprocessing), convert-shape-to-shapex, iree-flow-flatten-tuples-in-cfg, inline{disable-simplify=false max-iterations=4}, func(canonicalize, cse), iree-flow-legalize-input-types, func(iree-flow-materialize-exported-reflection, iree-shape-expand-function-dynamic-dims, iree-flow-merge-exported-reflection, iree-shape-tie-dynamic, iree-shape-materialize-calculations, iree-shape-hoist-shape-calculations, iree-flow-pre-partitioning-conversion), iree-flow-dispatchability-analysis, func(iree-flow-identify-dispatch-regions2, canonicalize, cse, canonicalize, iree-flow-rematerialize-dispatch-constants), iree-flow-outline-dispatch-regions, func(canonicalize, iree-flow-post-partitioning-conversion, canonicalize, cse, iree-flow-hoist-unstreamable-ops, canonicalize, iree-flow-form-streams, canonicalize, cse), symbol-dce, canonicalize, iree-hal-materialize-interfaces, hal.executable(iree-hal-translate-executables), iree-hal-link-executables, iree-convert-flow-to-hal, func(iree-shape-expand-function-ranked-shape-dims, canonicalize, cse), iree-hal-public-abi-generation, iree-hal-materialize-resource-caches, func(iree-hal-inline-device-switches), iree-hal-memoize-device-queries, func(canonicalize, cse), hal.executable(iree-hal-serialize-executables), symbol-dce, canonicalize, iree-vm-conversion, vm.module(iree-vm-global-initialization), inline{disable-simplify=false max-iterations=4}, cse, symbol-dce' --mlir-elide-elementsattrs-if-larger=100 ../data/mlp_reproducer.mlir
// *** IR Dump After mlir::iree_compiler::Shape::`anonymous-namespace'::MaterializeShapeCalculationsPass ***
// @predict as dumped after shape calculations were materialized. The signature
// now carries an explicit !shapex.ranked_shape argument (%arg1) for the input
// and returns a ranked shape alongside the result tensor. Dynamic dimensions
// are computed with ranked_dim / cmpi / select / make_ranked_shape, and every
// dynamically shaped tensor value is associated with its shape via tie_shape.
func @predict(%arg0: tensor<?x784xf32> {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?,784]> {iree.reflection = {}}) -> (tensor<?x10xf32> {iree.reflection = {}}, !shapex.ranked_shape<[?,10]> {iree.reflection = {}}) attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, f = "I11!B8!d-1d784R10!B7!d-1d10", fv = "1", sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<?x784>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
// Scalar constants: 0.5 (activation), 0xFF800000 (f32 bit pattern of -infinity,
// init of the max reduction), 0.0 (init of the sum reduction).
%0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
%1 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// Weights and biases appear here as inline constants.
// NOTE(review): the opaque 0xDEADBEEF payloads are presumably placeholders for
// large element data elided when printing; only the 10-element bias is shown
// in full. Confirm against the printing flags used for this dump.
%cst = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
%cst_0 = constant dense<[-1.1745497, -6.830580e-01, -0.22088325, 1.9310919, -0.511352897, -0.557739139, -0.712474167, 0.261824518, 0.724062442, 1.56302774]> : tensor<10xf32>
%cst_1 = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
%cst_2 = constant opaque<"", "0xDEADBEEF"> : tensor<256x256xf32>
%cst_3 = constant opaque<"", "0xDEADBEEF"> : tensor<256x10xf32>
%cst_4 = constant opaque<"", "0xDEADBEEF"> : tensor<784x256xf32>
%c1 = constant 1 : index
// Tie the input tensor to its shape argument, then run the first dot.
%3 = shapex.tie_shape %arg0, %arg1 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
%4 = "xla_hlo.dot"(%3, %cst_4) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
// The dot result shape [?,256] takes its dynamic dimension from dim 0 of %arg1.
%5 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,784]> -> index
%6 = shapex.make_ranked_shape %5 : (index) -> !shapex.ranked_shape<[?,256]>
%7 = shapex.tie_shape %4, %6 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Broadcast result dim 0: select(1 >u %5, 1, %5), i.e. the unsigned maximum of
// 1 and %5. %10 is the [?,256] shape used for the whole first stage.
%8 = cmpi "ugt", %c1, %5 : index
%9 = select %8, %c1, %5 : index
%10 = shapex.make_ranked_shape %9 : (index) -> !shapex.ranked_shape<[?,256]>
// Stage 1 bias add: broadcast the dot result and %cst_1 to %10, then add.
%11 = "shapex.ranked_broadcast_in_dim"(%7, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%12 = shapex.tie_shape %11, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%13 = "shapex.ranked_broadcast_in_dim"(%cst_1, %10) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%14 = shapex.tie_shape %13, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%15 = xla_hlo.add %12, %14 : tensor<?x256xf32>
%16 = shapex.tie_shape %15, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Stage 1 activation: tanh(x * 0.5) * 0.5 + 0.5 with the broadcast 0.5 tensor %18.
%17 = "shapex.ranked_broadcast_in_dim"(%0, %10) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%18 = shapex.tie_shape %17, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%19 = xla_hlo.multiply %16, %18 : tensor<?x256xf32>
%20 = shapex.tie_shape %19, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%21 = "xla_hlo.tanh"(%20) : (tensor<?x256xf32>) -> tensor<?x256xf32>
%22 = shapex.tie_shape %21, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%23 = xla_hlo.multiply %22, %18 : tensor<?x256xf32>
%24 = shapex.tie_shape %23, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%25 = xla_hlo.add %24, %18 : tensor<?x256xf32>
%26 = shapex.tie_shape %25, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Stage 2: dot with %cst_2; the result is tied to the stage 1 shape %10.
%27 = "xla_hlo.dot"(%26, %cst_2) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
%28 = shapex.tie_shape %27, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// The same select(1 >u d, 1, d) pattern is applied again, now to %9, giving the
// stage 2 shape %31.
%29 = cmpi "ugt", %c1, %9 : index
%30 = select %29, %c1, %9 : index
%31 = shapex.make_ranked_shape %30 : (index) -> !shapex.ranked_shape<[?,256]>
%32 = "shapex.ranked_broadcast_in_dim"(%28, %31) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%33 = shapex.tie_shape %32, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%34 = "shapex.ranked_broadcast_in_dim"(%cst, %31) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%35 = shapex.tie_shape %34, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%36 = xla_hlo.add %33, %35 : tensor<?x256xf32>
%37 = shapex.tie_shape %36, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Stage 2 activation, same form as stage 1, with the broadcast 0.5 tensor %39.
%38 = "shapex.ranked_broadcast_in_dim"(%0, %31) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%39 = shapex.tie_shape %38, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%40 = xla_hlo.multiply %37, %39 : tensor<?x256xf32>
%41 = shapex.tie_shape %40, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%42 = "xla_hlo.tanh"(%41) : (tensor<?x256xf32>) -> tensor<?x256xf32>
%43 = shapex.tie_shape %42, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%44 = xla_hlo.multiply %43, %39 : tensor<?x256xf32>
%45 = shapex.tie_shape %44, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%46 = xla_hlo.add %45, %39 : tensor<?x256xf32>
%47 = shapex.tie_shape %46, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Stage 3: dot with %cst_3; the [?,10] result shape %49 reuses the dynamic dim %30.
%48 = "xla_hlo.dot"(%47, %cst_3) : (tensor<?x256xf32>, tensor<256x10xf32>) -> tensor<?x10xf32>
%49 = shapex.make_ranked_shape %30 : (index) -> !shapex.ranked_shape<[?,10]>
%50 = shapex.tie_shape %48, %49 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
// Third application of the select(1 >u d, 1, d) pattern, on %30. %53 is the
// [?,10] shape used for the rest of the function and returned to the caller.
%51 = cmpi "ugt", %c1, %30 : index
%52 = select %51, %c1, %30 : index
%53 = shapex.make_ranked_shape %52 : (index) -> !shapex.ranked_shape<[?,10]>
%54 = "shapex.ranked_broadcast_in_dim"(%50, %53) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%55 = shapex.tie_shape %54, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%56 = "shapex.ranked_broadcast_in_dim"(%cst_0, %53) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%57 = shapex.tie_shape %56, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%58 = xla_hlo.add %55, %57 : tensor<?x10xf32>
%59 = shapex.tie_shape %58, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
// Stage 3 activation, same form again, with the broadcast 0.5 tensor %61.
%60 = "shapex.ranked_broadcast_in_dim"(%0, %53) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%61 = shapex.tie_shape %60, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%62 = xla_hlo.multiply %59, %61 : tensor<?x10xf32>
%63 = shapex.tie_shape %62, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%64 = "xla_hlo.tanh"(%63) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%65 = shapex.tie_shape %64, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%66 = xla_hlo.multiply %65, %61 : tensor<?x10xf32>
%67 = shapex.tie_shape %66, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%68 = xla_hlo.add %67, %61 : tensor<?x10xf32>
%69 = shapex.tie_shape %68, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
// Maximum of %69 over dimension 1, starting from the -infinity constant %1.
%70 = "xla_hlo.reduce"(%69, %1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%86 = xla_hlo.maximum %arg2, %arg3 : tensor<f32>
"xla_hlo.return"(%86) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
// The rank-1 result shape is built from %52; broadcast the per-row maximum
// back to %53, subtract it and exponentiate.
%71 = shapex.make_ranked_shape %52 : (index) -> !shapex.ranked_shape<[?]>
%72 = shapex.tie_shape %70, %71 : tensor<?xf32>, !shapex.ranked_shape<[?]>
%73 = "shapex.ranked_broadcast_in_dim"(%72, %53) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%74 = shapex.tie_shape %73, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%75 = xla_hlo.subtract %69, %74 : tensor<?x10xf32>
%76 = shapex.tie_shape %75, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%77 = "xla_hlo.exponential"(%76) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%78 = shapex.tie_shape %77, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
// Sum of the exponentials over dimension 1, starting from the 0.0 constant %2.
%79 = "xla_hlo.reduce"(%78, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%86 = xla_hlo.add %arg2, %arg3 : tensor<f32>
"xla_hlo.return"(%86) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
// %80 rebuilds the same rank-1 shape as %71 (same operand %52); broadcast the
// per-row sums, divide, and return the result together with its shape %53.
%80 = shapex.make_ranked_shape %52 : (index) -> !shapex.ranked_shape<[?]>
%81 = shapex.tie_shape %79, %80 : tensor<?xf32>, !shapex.ranked_shape<[?]>
%82 = "shapex.ranked_broadcast_in_dim"(%81, %53) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%83 = shapex.tie_shape %82, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%84 = xla_hlo.divide %78, %83 : tensor<?x10xf32>
%85 = shapex.tie_shape %84, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
return %85, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
}// *** IR Dump After mlir::iree_compiler::Shape::`anonymous-namespace'::HoistShapeCalculations ***
// @predict as dumped after shape calculations were hoisted. All shape
// computations (ranked_dim, the cmpi/select chain, make_ranked_shape) now sit
// at the top of the body, ahead of the constants and tensor ops; the tensor ops
// below only reference the resulting shape values through tie_shape and
// ranked_broadcast_in_dim.
func @predict(%arg0: tensor<?x784xf32> {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?,784]> {iree.reflection = {}}) -> (tensor<?x10xf32> {iree.reflection = {}}, !shapex.ranked_shape<[?,10]> {iree.reflection = {}}) attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, f = "I11!B8!d-1d784R10!B7!d-1d10", fv = "1", sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<?x784>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
// %0 is dim 0 of the input shape; %1 is the [?,256] shape of the first dot result.
%0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,784]> -> index
%1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,256]>
%c1 = constant 1 : index
// Three chained select(1 >u d, 1, d) steps (unsigned maximum of 1 and d):
// %3 from %0, %5 from %3, %7 from %5.
%2 = cmpi "ugt", %c1, %0 : index
%3 = select %2, %c1, %0 : index
%4 = cmpi "ugt", %c1, %3 : index
%5 = select %4, %c1, %3 : index
%6 = cmpi "ugt", %c1, %5 : index
%7 = select %6, %c1, %5 : index
// Shapes built from those dims. %8 and %9 are identical rank-1 shapes from %7
// (tied to the two reduce results below); %10 is [?,10] from %7; %11 is [?,10]
// from %5; %12 is [?,256] from %5; %13 is [?,256] from %3.
%8 = shapex.make_ranked_shape %7 : (index) -> !shapex.ranked_shape<[?]>
%9 = shapex.make_ranked_shape %7 : (index) -> !shapex.ranked_shape<[?]>
%10 = shapex.make_ranked_shape %7 : (index) -> !shapex.ranked_shape<[?,10]>
%11 = shapex.make_ranked_shape %5 : (index) -> !shapex.ranked_shape<[?,10]>
%12 = shapex.make_ranked_shape %5 : (index) -> !shapex.ranked_shape<[?,256]>
%13 = shapex.make_ranked_shape %3 : (index) -> !shapex.ranked_shape<[?,256]>
// Scalar constants: 0.5 (activation), 0xFF800000 (f32 bit pattern of -infinity,
// init of the max reduction), 0.0 (init of the sum reduction).
%14 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
%15 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
%16 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// Weight and bias constants.
// NOTE(review): the opaque 0xDEADBEEF payloads are presumably placeholders for
// large element data elided when printing; confirm against the printing flags.
%cst = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
%cst_0 = constant dense<[-1.1745497, -6.830580e-01, -0.22088325, 1.9310919, -0.511352897, -0.557739139, -0.712474167, 0.261824518, 0.724062442, 1.56302774]> : tensor<10xf32>
%cst_1 = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
%cst_2 = constant opaque<"", "0xDEADBEEF"> : tensor<256x256xf32>
%cst_3 = constant opaque<"", "0xDEADBEEF"> : tensor<256x10xf32>
%cst_4 = constant opaque<"", "0xDEADBEEF"> : tensor<784x256xf32>
// Stage 1: tie the input to its shape, dot with %cst_4, broadcast to %13 and
// add the bias %cst_1.
%17 = shapex.tie_shape %arg0, %arg1 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
%18 = "xla_hlo.dot"(%17, %cst_4) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
%19 = shapex.tie_shape %18, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%20 = "shapex.ranked_broadcast_in_dim"(%19, %13) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%21 = shapex.tie_shape %20, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%22 = "shapex.ranked_broadcast_in_dim"(%cst_1, %13) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%23 = shapex.tie_shape %22, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%24 = xla_hlo.add %21, %23 : tensor<?x256xf32>
%25 = shapex.tie_shape %24, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Stage 1 activation: tanh(x * 0.5) * 0.5 + 0.5 with the broadcast 0.5 tensor %27.
%26 = "shapex.ranked_broadcast_in_dim"(%14, %13) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%27 = shapex.tie_shape %26, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%28 = xla_hlo.multiply %25, %27 : tensor<?x256xf32>
%29 = shapex.tie_shape %28, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%30 = "xla_hlo.tanh"(%29) : (tensor<?x256xf32>) -> tensor<?x256xf32>
%31 = shapex.tie_shape %30, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%32 = xla_hlo.multiply %31, %27 : tensor<?x256xf32>
%33 = shapex.tie_shape %32, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%34 = xla_hlo.add %33, %27 : tensor<?x256xf32>
%35 = shapex.tie_shape %34, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Stage 2: dot with %cst_2 (tied to %13), broadcast to %12 and add the bias %cst.
%36 = "xla_hlo.dot"(%35, %cst_2) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
%37 = shapex.tie_shape %36, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%38 = "shapex.ranked_broadcast_in_dim"(%37, %12) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%39 = shapex.tie_shape %38, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%40 = "shapex.ranked_broadcast_in_dim"(%cst, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%41 = shapex.tie_shape %40, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%42 = xla_hlo.add %39, %41 : tensor<?x256xf32>
%43 = shapex.tie_shape %42, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Stage 2 activation, same form, with the broadcast 0.5 tensor %45.
%44 = "shapex.ranked_broadcast_in_dim"(%14, %12) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%45 = shapex.tie_shape %44, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%46 = xla_hlo.multiply %43, %45 : tensor<?x256xf32>
%47 = shapex.tie_shape %46, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%48 = "xla_hlo.tanh"(%47) : (tensor<?x256xf32>) -> tensor<?x256xf32>
%49 = shapex.tie_shape %48, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%50 = xla_hlo.multiply %49, %45 : tensor<?x256xf32>
%51 = shapex.tie_shape %50, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%52 = xla_hlo.add %51, %45 : tensor<?x256xf32>
%53 = shapex.tie_shape %52, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Stage 3: dot with %cst_3 (tied to %11), broadcast to %10 and add the bias %cst_0.
%54 = "xla_hlo.dot"(%53, %cst_3) : (tensor<?x256xf32>, tensor<256x10xf32>) -> tensor<?x10xf32>
%55 = shapex.tie_shape %54, %11 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%56 = "shapex.ranked_broadcast_in_dim"(%55, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%57 = shapex.tie_shape %56, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%58 = "shapex.ranked_broadcast_in_dim"(%cst_0, %10) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%59 = shapex.tie_shape %58, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%60 = xla_hlo.add %57, %59 : tensor<?x10xf32>
%61 = shapex.tie_shape %60, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
// Stage 3 activation, same form, with the broadcast 0.5 tensor %63.
%62 = "shapex.ranked_broadcast_in_dim"(%14, %10) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%63 = shapex.tie_shape %62, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%64 = xla_hlo.multiply %61, %63 : tensor<?x10xf32>
%65 = shapex.tie_shape %64, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%66 = "xla_hlo.tanh"(%65) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%67 = shapex.tie_shape %66, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%68 = xla_hlo.multiply %67, %63 : tensor<?x10xf32>
%69 = shapex.tie_shape %68, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%70 = xla_hlo.add %69, %63 : tensor<?x10xf32>
%71 = shapex.tie_shape %70, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
// Maximum of %71 over dimension 1, starting from the -infinity constant %15.
%72 = "xla_hlo.reduce"(%71, %15) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%86 = xla_hlo.maximum %arg2, %arg3 : tensor<f32>
"xla_hlo.return"(%86) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
// Broadcast the per-row maximum back to %10, subtract it and exponentiate.
%73 = shapex.tie_shape %72, %9 : tensor<?xf32>, !shapex.ranked_shape<[?]>
%74 = "shapex.ranked_broadcast_in_dim"(%73, %10) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%75 = shapex.tie_shape %74, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%76 = xla_hlo.subtract %71, %75 : tensor<?x10xf32>
%77 = shapex.tie_shape %76, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%78 = "xla_hlo.exponential"(%77) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%79 = shapex.tie_shape %78, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
// Sum of the exponentials over dimension 1, starting from the 0.0 constant %16.
%80 = "xla_hlo.reduce"(%79, %16) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%86 = xla_hlo.add %arg2, %arg3 : tensor<f32>
"xla_hlo.return"(%86) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
// Broadcast the per-row sums to %10, divide, and return the result together
// with its shape %10.
%81 = shapex.tie_shape %80, %8 : tensor<?xf32>, !shapex.ranked_shape<[?]>
%82 = "shapex.ranked_broadcast_in_dim"(%81, %10) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%83 = shapex.tie_shape %82, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%84 = xla_hlo.divide %79, %83 : tensor<?x10xf32>
%85 = shapex.tie_shape %84, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
return %85, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
}This has a patch in it that disables fusion of anything with dynamic shapes, which works around fusion issues on the backend (see later).
Note that this is after a hacky canonicalize-cse-canonicalize sequence, because it is hard to do the tie_shape gymnastics in one go.
The shape-related phase-ordering issue here is that forming the dispatch regions around anchor ("hero" in XLA parlance) ops requires reasoning about the "workload", which in turn requires reasoning about the shape in terms of calculations operating on actual dimension values. Having access to these dimension values as materialized SSA values seems important at this phase.
// *** IR Dump After Canonicalizer ***
func @predict(%arg0: tensor<?x784xf32> {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?,784]> {iree.reflection = {}}) -> (tensor<?x10xf32> {iree.reflection = {}}, !shapex.ranked_shape<[?,10]> {iree.reflection = {}}) attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, f = "I11!B8!d-1d784R10!B7!d-1d10", fv = "1", sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<?x784>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
%c1 = constant 1 : index
%c10 = constant 10 : index
%c256 = constant 256 : index
%cst = constant dense<5.000000e-01> : tensor<f32>
%cst_0 = constant dense<0xFF800000> : tensor<f32>
%cst_1 = constant dense<0.000000e+00> : tensor<f32>
%cst_2 = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
%cst_3 = constant dense<[-1.1745497, -6.830580e-01, -0.22088325, 1.9310919, -0.511352897, -0.557739139, -0.712474167, 0.261824518, 0.724062442, 1.56302774]> : tensor<10xf32>
%cst_4 = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
%cst_5 = constant opaque<"", "0xDEADBEEF"> : tensor<256x256xf32>
%cst_6 = constant opaque<"", "0xDEADBEEF"> : tensor<256x10xf32>
%cst_7 = constant opaque<"", "0xDEADBEEF"> : tensor<784x256xf32>
%0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,784]> -> index
%1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,256]>
%2 = muli %0, %c256 : index
%3 = cmpi "ugt", %c1, %0 : index
%4 = select %3, %c1, %0 : index
%5 = cmpi "ugt", %c1, %4 : index
%6 = select %5, %c1, %4 : index
%7 = cmpi "ugt", %c1, %6 : index
%8 = select %7, %c1, %6 : index
%9 = shapex.make_ranked_shape %8 : (index) -> !shapex.ranked_shape<[?]>
%10 = shapex.make_ranked_shape %8 : (index) -> !shapex.ranked_shape<[?,10]>
%11 = muli %8, %c10 : index
%12 = shapex.make_ranked_shape %6 : (index) -> !shapex.ranked_shape<[?,10]>
%13 = muli %6, %c10 : index
%14 = shapex.make_ranked_shape %6 : (index) -> !shapex.ranked_shape<[?,256]>
%15 = muli %6, %c256 : index
%16 = shapex.make_ranked_shape %4 : (index) -> !shapex.ranked_shape<[?,256]>
%17 = muli %4, %c256 : index
%18 = shapex.tie_shape %arg0, %arg1 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
%19 = flow.dispatch.region[%2 : index](%arg2 = %cst_7 : tensor<784x256xf32>, %arg3 = %1 : !shapex.ranked_shape<[?,256]>, %arg4 = %18 : tensor<?x784xf32>, %arg5 = %arg1 : !shapex.ranked_shape<[?,784]>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg5 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
%96 = "xla_hlo.dot"(%95, %arg2) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
%97 = shapex.tie_shape %96, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %97 : tensor<?x256xf32>
}
%20 = shapex.tie_shape %19, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%21 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %20 : tensor<?x256xf32>, %arg4 = %1 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg3, %arg4 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %97 : tensor<?x256xf32>
}
%22 = shapex.tie_shape %21, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%23 = flow.dispatch.region[%17 : index](%arg2 = %cst_4 : tensor<256xf32>, %arg3 = %16 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
%95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%96 = shapex.tie_shape %95, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %96 : tensor<?x256xf32>
}
%24 = shapex.tie_shape %23, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%25 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %24 : tensor<?x256xf32>, %arg4 = %22 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%97 = xla_hlo.add %95, %96 : tensor<?x256xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %98 : tensor<?x256xf32>
}
%26 = flow.dispatch.region[%17 : index](%arg2 = %cst : tensor<f32>, %arg3 = %16 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
%95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%96 = shapex.tie_shape %95, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %96 : tensor<?x256xf32>
}
%27 = shapex.tie_shape %26, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%28 = shapex.tie_shape %26, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%29 = shapex.tie_shape %26, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%30 = shapex.tie_shape %25, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%31 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %29 : tensor<?x256xf32>, %arg4 = %30 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%97 = xla_hlo.multiply %95, %96 : tensor<?x256xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %98 : tensor<?x256xf32>
}
%32 = shapex.tie_shape %31, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%33 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %32 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = "xla_hlo.tanh"(%95) : (tensor<?x256xf32>) -> tensor<?x256xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %97 : tensor<?x256xf32>
}
%34 = shapex.tie_shape %33, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%35 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %34 : tensor<?x256xf32>, %arg4 = %28 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%97 = xla_hlo.multiply %96, %95 : tensor<?x256xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %98 : tensor<?x256xf32>
}
%36 = shapex.tie_shape %35, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%37 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %36 : tensor<?x256xf32>, %arg4 = %27 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%97 = xla_hlo.add %96, %95 : tensor<?x256xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %98 : tensor<?x256xf32>
}
%38 = shapex.tie_shape %37, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%39 = flow.dispatch.region[%17 : index](%arg2 = %cst_5 : tensor<256x256xf32>, %arg3 = %16 : !shapex.ranked_shape<[?,256]>, %arg4 = %38 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = "xla_hlo.dot"(%95, %arg2) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
%97 = shapex.tie_shape %96, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %97 : tensor<?x256xf32>
}
%40 = shapex.tie_shape %39, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%41 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %40 : tensor<?x256xf32>, %arg4 = %16 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg3, %arg4 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %97 : tensor<?x256xf32>
}
%42 = shapex.tie_shape %41, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%43 = flow.dispatch.region[%15 : index](%arg2 = %cst_2 : tensor<256xf32>, %arg3 = %14 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
%95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%96 = shapex.tie_shape %95, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %96 : tensor<?x256xf32>
}
%44 = shapex.tie_shape %43, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%45 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %44 : tensor<?x256xf32>, %arg4 = %42 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%97 = xla_hlo.add %95, %96 : tensor<?x256xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %98 : tensor<?x256xf32>
}
%46 = flow.dispatch.region[%15 : index](%arg2 = %cst : tensor<f32>, %arg3 = %14 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
%95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
%96 = shapex.tie_shape %95, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %96 : tensor<?x256xf32>
}
%47 = shapex.tie_shape %46, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%48 = shapex.tie_shape %46, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%49 = shapex.tie_shape %46, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%50 = shapex.tie_shape %45, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%51 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %49 : tensor<?x256xf32>, %arg4 = %50 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%97 = xla_hlo.multiply %95, %96 : tensor<?x256xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %98 : tensor<?x256xf32>
}
%52 = shapex.tie_shape %51, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%53 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %52 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = "xla_hlo.tanh"(%95) : (tensor<?x256xf32>) -> tensor<?x256xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %97 : tensor<?x256xf32>
}
%54 = shapex.tie_shape %53, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%55 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %54 : tensor<?x256xf32>, %arg4 = %48 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%97 = xla_hlo.multiply %96, %95 : tensor<?x256xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %98 : tensor<?x256xf32>
}
%56 = shapex.tie_shape %55, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%57 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %56 : tensor<?x256xf32>, %arg4 = %47 : tensor<?x256xf32>) -> tensor<?x256xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%97 = xla_hlo.add %96, %95 : tensor<?x256xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
flow.return %98 : tensor<?x256xf32>
}
%58 = shapex.tie_shape %57, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%59 = flow.dispatch.region[%13 : index](%arg2 = %cst_6 : tensor<256x10xf32>, %arg3 = %12 : !shapex.ranked_shape<[?,10]>, %arg4 = %58 : tensor<?x256xf32>, %arg5 = %14 : !shapex.ranked_shape<[?,256]>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg4, %arg5 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%96 = "xla_hlo.dot"(%95, %arg2) : (tensor<?x256xf32>, tensor<256x10xf32>) -> tensor<?x10xf32>
%97 = shapex.tie_shape %96, %arg3 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %97 : tensor<?x10xf32>
}
%60 = shapex.tie_shape %59, %12 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%61 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %60 : tensor<?x10xf32>, %arg4 = %12 : !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg3, %arg4 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %97 : tensor<?x10xf32>
}
%62 = shapex.tie_shape %61, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%63 = flow.dispatch.region[%11 : index](%arg2 = %cst_3 : tensor<10xf32>, %arg3 = %10 : !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32> {
%95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%96 = shapex.tie_shape %95, %arg3 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %96 : tensor<?x10xf32>
}
%64 = shapex.tie_shape %63, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%65 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %64 : tensor<?x10xf32>, %arg4 = %62 : tensor<?x10xf32>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%97 = xla_hlo.add %95, %96 : tensor<?x10xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %98 : tensor<?x10xf32>
}
%66 = flow.dispatch.region[%11 : index](%arg2 = %cst : tensor<f32>, %arg3 = %10 : !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32> {
%95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%96 = shapex.tie_shape %95, %arg3 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %96 : tensor<?x10xf32>
}
%67 = shapex.tie_shape %66, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%68 = shapex.tie_shape %66, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%69 = shapex.tie_shape %66, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%70 = shapex.tie_shape %65, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%71 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %69 : tensor<?x10xf32>, %arg4 = %70 : tensor<?x10xf32>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%97 = xla_hlo.multiply %95, %96 : tensor<?x10xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %98 : tensor<?x10xf32>
}
%72 = shapex.tie_shape %71, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%73 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %72 : tensor<?x10xf32>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = "xla_hlo.tanh"(%95) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %97 : tensor<?x10xf32>
}
%74 = shapex.tie_shape %73, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%75 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %74 : tensor<?x10xf32>, %arg4 = %68 : tensor<?x10xf32>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%97 = xla_hlo.multiply %96, %95 : tensor<?x10xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %98 : tensor<?x10xf32>
}
%76 = shapex.tie_shape %75, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%77 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %76 : tensor<?x10xf32>, %arg4 = %67 : tensor<?x10xf32>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%97 = xla_hlo.add %96, %95 : tensor<?x10xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %98 : tensor<?x10xf32>
}
%78 = shapex.tie_shape %77, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%79 = flow.dispatch.region[%8 : index](%arg2 = %cst_0 : tensor<f32>, %arg3 = %9 : !shapex.ranked_shape<[?]>, %arg4 = %78 : tensor<?x10xf32>, %arg5 = %10 : !shapex.ranked_shape<[?,10]>) -> tensor<?xf32> {
%95 = shapex.tie_shape %arg4, %arg5 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = "xla_hlo.reduce"(%95, %arg2) ( {
^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
%98 = xla_hlo.maximum %arg6, %arg7 : tensor<f32>
"xla_hlo.return"(%98) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
%97 = shapex.tie_shape %96, %arg3 : tensor<?xf32>, !shapex.ranked_shape<[?]>
flow.return %97 : tensor<?xf32>
}
%80 = shapex.tie_shape %79, %9 : tensor<?xf32>, !shapex.ranked_shape<[?]>
%81 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %80 : tensor<?xf32>, %arg4 = %9 : !shapex.ranked_shape<[?]>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg3, %arg4 : tensor<?xf32>, !shapex.ranked_shape<[?]>
%96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %97 : tensor<?x10xf32>
}
%82 = shapex.tie_shape %81, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%83 = shapex.tie_shape %77, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%84 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %82 : tensor<?x10xf32>, %arg4 = %83 : tensor<?x10xf32>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%97 = xla_hlo.subtract %95, %96 : tensor<?x10xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %98 : tensor<?x10xf32>
}
%85 = shapex.tie_shape %84, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%86 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %85 : tensor<?x10xf32>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = "xla_hlo.exponential"(%95) : (tensor<?x10xf32>) -> tensor<?x10xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %97 : tensor<?x10xf32>
}
%87 = shapex.tie_shape %86, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%88 = flow.dispatch.region[%8 : index](%arg2 = %cst_1 : tensor<f32>, %arg3 = %9 : !shapex.ranked_shape<[?]>, %arg4 = %87 : tensor<?x10xf32>, %arg5 = %10 : !shapex.ranked_shape<[?,10]>) -> tensor<?xf32> {
%95 = shapex.tie_shape %arg4, %arg5 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = "xla_hlo.reduce"(%95, %arg2) ( {
^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>): // no predecessors
%98 = xla_hlo.add %arg6, %arg7 : tensor<f32>
"xla_hlo.return"(%98) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
%97 = shapex.tie_shape %96, %arg3 : tensor<?xf32>, !shapex.ranked_shape<[?]>
flow.return %97 : tensor<?xf32>
}
%89 = shapex.tie_shape %88, %9 : tensor<?xf32>, !shapex.ranked_shape<[?]>
%90 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %89 : tensor<?xf32>, %arg4 = %9 : !shapex.ranked_shape<[?]>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg3, %arg4 : tensor<?xf32>, !shapex.ranked_shape<[?]>
%96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
%97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %97 : tensor<?x10xf32>
}
%91 = shapex.tie_shape %90, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%92 = shapex.tie_shape %86, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%93 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %91 : tensor<?x10xf32>, %arg4 = %92 : tensor<?x10xf32>) -> tensor<?x10xf32> {
%95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
%97 = xla_hlo.divide %95, %96 : tensor<?x10xf32>
%98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
flow.return %98 : tensor<?x10xf32>
}
%94 = shapex.tie_shape %93, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
return %94, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
}

In the next step, the dispatch regions are outlined so that each can be compiled individually. At this phase, we also remove all but the root tie_shapes at the inputs, as we are in a form where the representation has sufficient information to reconstruct all unknown dims from these roots.
flow.executable @predict_ex_dispatch_9 attributes {sym_visibility = "private"} {
flow.dispatch.entry @predict_ex_dispatch_9
module {
func @predict_ex_dispatch_9(%arg0: tensor<256x256xf32>, %arg1: index, %arg2: tensor<?x256xf32>) -> tensor<?x256xf32> {
%0 = shapex.make_ranked_shape %arg1 : (index) -> !shapex.ranked_shape<[?,256]>
%1 = shapex.tie_shape %arg2, %0 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%2 = "xla_hlo.dot"(%1, %arg0) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
return %2 : tensor<?x256xf32>
}
}
}

Note that after outlining and compiling dispatch regions, the tie_shapes in the
original, outer function are the only remaining way to re-discover full
shapes -- since the flow.dispatch calls do not have a way to reify a shape
calculation generally. Example:
%18 = shapex.tie_shape %arg0, %arg1 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
%19 = shapex.ranked_dim %1[0] : !shapex.ranked_shape<[?,256]> -> index
%20 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,784]> -> index
%21 = flow.dispatch @predict_ex_dispatch_0::@predict_ex_dispatch_0[%2 : index](%cst_4, %19, %18, %20) : (tensor<784x256xf32>, index, tensor<?x784xf32>, index) -> tensor<?x256xf32>
%22 = shapex.tie_shape %21, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>

For each dispatch region, interfaces are materialized, with the tie_shapes at the roots. Example:
hal.executable @predict_ex_dispatch_0 attributes {sym_visibility = "private"} {
hal.interface @legacy_io attributes {push_constants = 2 : i32} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
hal.executable.entry_point @predict_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<784x256xf32>, index, tensor<?x784xf32>, index) -> tensor<?x256xf32>}
hal.executable.target "vulkan*" {
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
func @predict_ex_dispatch_0() {
%c0 = constant 0 : index
%0 = hal.interface.load.constant offset = 0 : index
%1 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<784x256xf32>
%2 = hal.interface.load.constant offset = 1 : index
%3 = hal.interface.load.tensor @legacy_io::@arg2, offset = %c0 : tensor<?x784xf32>
%4 = call @predict_ex_dispatch_0_impl(%1, %0, %3, %2) : (tensor<784x256xf32>, index, tensor<?x784xf32>, index) -> tensor<?x256xf32>
hal.interface.store.tensor %4, @legacy_io::@ret0, offset = %c0 : tensor<?x256xf32>
return
}
func @predict_ex_dispatch_0_impl(%arg0: tensor<784x256xf32>, %arg1: index, %arg2: tensor<?x784xf32>, %arg3: index) -> tensor<?x256xf32> attributes {sym_visibility = "private"} {
%0 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?,784]>
%1 = shapex.tie_shape %arg2, %0 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
%2 = "xla_hlo.dot"(%1, %arg0) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
return %2 : tensor<?x256xf32>
}
hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
}
}
}

For a specific executable, shape calculations are reified:
// *** IR Dump After mlir::iree_compiler::Shape::`anonymous-namespace'::HoistShapeCalculations ***
func @predict_ex_dispatch_0() {
%0 = hal.interface.load.constant offset = 1 : index
%1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,256]>
%2 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,784]>
%c0 = constant 0 : index
%3 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<784x256xf32>
%4 = hal.interface.load.tensor @legacy_io::@arg2, offset = %c0 : tensor<?x784xf32>
%5 = shapex.tie_shape %4, %2 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
%6 = "xla_hlo.dot"(%5, %3) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
%7 = shapex.tie_shape %6, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
hal.interface.store.tensor %7, @legacy_io::@ret0, offset = %c0 : tensor<?x256xf32>
return
}

And after lowering to buffers:
// *** IR Dump After mlir::iree_compiler::`anonymous-namespace'::ConvertHLOToLinalgOnBuffersPass ***
func @predict_ex_dispatch_0() {
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x256xf32>
%c0 = constant 0 : index
%1 = hal.interface.load.constant offset = 1 : index
%2 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,256]>
%3 = shapex.tie_shape %0, %2 : memref<?x256xf32>, !shapex.ranked_shape<[?,256]>
%4 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,784]>
%5 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<784x256xf32>
%6 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x784xf32>
%7 = shapex.tie_shape %6, %4 : memref<?x784xf32>, !shapex.ranked_shape<[?,784]>
%cst = constant 0.000000e+00 : f32
linalg.fill(%3, %cst) : memref<?x256xf32>, f32
linalg.matmul(%7, %5, %3) : memref<?x784xf32>, memref<784x256xf32>, memref<?x256xf32>
return
}

After tile and fuse, the IR looks as follows. Note that this is where std.dim gets
introduced and is expected to be resolvable to a root of some kind.
// *** IR Dump After mlir::iree_compiler::`anonymous-namespace'::LinalgTileAndFusePass ***
func @predict_ex_dispatch_0() attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%c4 = constant 4 : index
%c784 = constant 784 : index
%c8 = constant 8 : index
%c256 = constant 256 : index
%c1 = constant 1 : index
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x256xf32>
%1 = hal.interface.load.constant offset = 1 : index
%2 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,256]>
%3 = shapex.tie_shape %0, %2 : memref<?x256xf32>, !shapex.ranked_shape<[?,256]>
%4 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,784]>
%5 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<784x256xf32>
%6 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x784xf32>
%7 = shapex.tie_shape %6, %4 : memref<?x784xf32>, !shapex.ranked_shape<[?,784]>
%8 = dim %3, 0 : memref<?x256xf32>
scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%8, %c256) step (%c8, %c8) {
%10 = dim %3, 0 : memref<?x256xf32>
%11 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %10, %arg0)
%12 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %c256, %arg1)
%13 = subview %3[%arg0, %arg1] [%11, %12] [%c1, %c1] : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
linalg.fill(%13, %cst) {__internal_linalg_transform__ = "workitem"} : memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>, f32
scf.yield
}
%9 = dim %7, 0 : memref<?x784xf32>
scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%9, %c256) step (%c8, %c8) {
scf.for %arg2 = %c0 to %c784 step %c4 {
%10 = dim %7, 0 : memref<?x784xf32>
%11 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %10, %arg0)
%12 = affine.min affine_map<(d0, d1, d2) -> (4, d1 - d2)>(%c4, %c784, %arg2)
%13 = subview %7[%arg0, %arg2] [%11, %12] [%c1, %c1] : memref<?x784xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
%14 = affine.min affine_map<(d0, d1, d2) -> (4, d1 - d2)>(%c4, %c784, %arg2)
%15 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %c256, %arg1)
%16 = subview %5[%arg2, %arg1] [%14, %15] [%c1, %c1] : memref<784x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
%17 = dim %3, 0 : memref<?x256xf32>
%18 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %17, %arg0)
%19 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %c256, %arg1)
%20 = subview %3[%arg0, %arg1] [%18, %19] [%c1, %c1] : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
linalg.matmul(%13, %16, %20) {__internal_linalg_transform__ = "workitem"} : memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
}
scf.yield
}
return
}

Prior to resolving:
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
// Dispatch entry point as it looks before shape ops are resolved: the dynamic
// row count is still queried with `dim` on shapex.tie_shape results (%3, %7).
// Two loop nests follow: the first stores %cst into tiles of the @ret0 buffer,
// the second accumulates products of @arg2 and @arg0 tiles into that buffer.
func @predict_ex_dispatch_0() attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
%cst = constant 0.000000e+00 : f32
%c4 = constant 4 : index
%c784 = constant 784 : index
%c8 = constant 8 : index
%c-1 = constant -1 : index
%c256 = constant 256 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
// Bind the interface buffers and tie dynamic shapes to them. Both ranked
// shapes (%2 and %4) are built from the same push constant %1.
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x256xf32>
%1 = hal.interface.load.constant offset = 1 : index
%2 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,256]>
%3 = shapex.tie_shape %0, %2 : memref<?x256xf32>, !shapex.ranked_shape<[?,256]>
%4 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,784]>
%5 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<784x256xf32>
%6 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x784xf32>
%7 = shapex.tie_shape %6, %4 : memref<?x784xf32>, !shapex.ranked_shape<[?,784]>
// Upper bound of the first loop nest: dim 0 of the tied result value %3.
%8 = dim %3, 0 : memref<?x256xf32>
// Outer-loop lower bounds (%13, %15) are block ids scaled by 8; the steps
// (%14, %16) are grid dims scaled by 8.
%9 = "gpu.block_id"() {dimension = "x"} : () -> index
%10 = "gpu.grid_dim"() {dimension = "x"} : () -> index
%11 = "gpu.block_id"() {dimension = "y"} : () -> index
%12 = "gpu.grid_dim"() {dimension = "y"} : () -> index
%13 = muli %11, %c8 : index
%14 = muli %12, %c8 : index
%15 = muli %9, %c8 : index
%16 = muli %10, %c8 : index
// Loop nest 1: store %cst into each tile of %3.
scf.for %arg0 = %13 to %8 step %14 {
scf.for %arg1 = %15 to %c256 step %16 {
// %21 = min(8, %8 - %arg0), spelled out as muli/addi/cmpi/select.
%18 = muli %arg0, %c-1 : index
%19 = addi %18, %8 : index
%20 = cmpi "slt", %c8, %19 : index
%21 = select %20, %c8, %19 : index
// %25 = min(8, 256 - %arg1).
%22 = muli %arg1, %c-1 : index
%23 = addi %22, %c256 : index
%24 = cmpi "slt", %c8, %23 : index
%25 = select %24, %c8, %23 : index
%26 = subview %3[%arg0, %arg1] [%21, %25] [1, 1] : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
// Inner loops start at the thread id and step by the block dim.
%27 = "gpu.thread_id"() {dimension = "x"} : () -> index
%28 = "gpu.block_dim"() {dimension = "x"} : () -> index
%29 = "gpu.thread_id"() {dimension = "y"} : () -> index
%30 = "gpu.block_dim"() {dimension = "y"} : () -> index
scf.for %arg2 = %29 to %21 step %30 {
scf.for %arg3 = %27 to %25 step %28 {
store %cst, %26[%arg2, %arg3] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
}
}
}
}
// Upper bound of the second loop nest: dim 0 of the tied @arg2 value %7.
%17 = dim %7, 0 : memref<?x784xf32>
// Loop nest 2: for each output tile, walk the 784-sized dimension in steps of 4
// and accumulate into %3.
scf.for %arg0 = %13 to %17 step %14 {
scf.for %arg1 = %15 to %c256 step %16 {
scf.for %arg2 = %c0 to %c784 step %c4 {
// %21 = min(8, %17 - %arg0).
%18 = muli %arg0, %c-1 : index
%19 = addi %18, %17 : index
%20 = cmpi "slt", %c8, %19 : index
%21 = select %20, %c8, %19 : index
// %25 = min(4, 784 - %arg2).
%22 = muli %arg2, %c-1 : index
%23 = addi %22, %c784 : index
%24 = cmpi "slt", %c4, %23 : index
%25 = select %24, %c4, %23 : index
%26 = subview %7[%arg0, %arg2] [%21, %25] [1, 1] : memref<?x784xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 784 + s0 + d1)>>
// %30 = min(8, 256 - %arg1).
%27 = muli %arg1, %c-1 : index
%28 = addi %27, %c256 : index
%29 = cmpi "slt", %c8, %28 : index
%30 = select %29, %c8, %28 : index
%31 = subview %5[%arg2, %arg1] [%25, %30] [1, 1] : memref<784x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
// %34 = min(8, %8 - %arg0): same form as %21, but based on the dim of %3 (%8)
// rather than the dim of %7 (%17).
%32 = addi %18, %8 : index
%33 = cmpi "slt", %c8, %32 : index
%34 = select %33, %c8, %32 : index
%35 = subview %3[%arg0, %arg1] [%34, %30] [1, 1] : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
%36 = "gpu.thread_id"() {dimension = "x"} : () -> index
%37 = "gpu.block_dim"() {dimension = "x"} : () -> index
%38 = "gpu.thread_id"() {dimension = "y"} : () -> index
%39 = "gpu.block_dim"() {dimension = "y"} : () -> index
// %35[%arg3, %arg4] += %26[%arg3, %arg5] * %31[%arg5, %arg4]
scf.for %arg3 = %38 to %21 step %39 {
scf.for %arg4 = %36 to %30 step %37 {
scf.for %arg5 = %c0 to %25 step %c1 {
%40 = load %31[%arg5, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
%41 = load %26[%arg3, %arg5] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 784 + s0 + d1)>>
%42 = mulf %41, %40 : f32
%43 = load %35[%arg3, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
%44 = addf %43, %42 : f32
store %44, %35[%arg3, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
}
}
}
}
}
}
return
}
// Interface declaring the bindings referenced as @legacy_io::@arg0, @arg2 and
// @ret0: two read-only storage buffers, one write/discard storage buffer, and
// two push constants.
hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
}

After std.dim resolving:
// *** IR Dump After mlir::iree_compiler::`anonymous-namespace'::ResolveShapeOpsPass ***
// Same dispatch entry point with no shapex or `dim` ops left in the body: the
// placeholders %0/%2/%3 are used directly, and the dynamic loop bounds and tile
// sizes read the push constant %1.
func @predict_ex_dispatch_0() attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
%cst = constant 0.000000e+00 : f32
%c4 = constant 4 : index
%c784 = constant 784 : index
%c8 = constant 8 : index
%c-1 = constant -1 : index
%c256 = constant 256 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
// Interface buffers plus the single push constant used as the dynamic extent.
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x256xf32>
%1 = hal.interface.load.constant offset = 1 : index
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<784x256xf32>
%3 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x784xf32>
// Outer-loop lower bounds (%8, %10) are block ids scaled by 8; the steps
// (%9, %11) are grid dims scaled by 8.
%4 = "gpu.block_id"() {dimension = "x"} : () -> index
%5 = "gpu.grid_dim"() {dimension = "x"} : () -> index
%6 = "gpu.block_id"() {dimension = "y"} : () -> index
%7 = "gpu.grid_dim"() {dimension = "y"} : () -> index
%8 = muli %6, %c8 : index
%9 = muli %7, %c8 : index
%10 = muli %4, %c8 : index
%11 = muli %5, %c8 : index
// Loop nest 1: store %cst into each tile of %0; the row bound is %1.
scf.for %arg0 = %8 to %1 step %9 {
scf.for %arg1 = %10 to %c256 step %11 {
// %15 = min(8, %1 - %arg0).
%12 = muli %arg0, %c-1 : index
%13 = addi %12, %1 : index
%14 = cmpi "slt", %c8, %13 : index
%15 = select %14, %c8, %13 : index
// %19 = min(8, 256 - %arg1).
%16 = muli %arg1, %c-1 : index
%17 = addi %16, %c256 : index
%18 = cmpi "slt", %c8, %17 : index
%19 = select %18, %c8, %17 : index
%20 = subview %0[%arg0, %arg1] [%15, %19] [1, 1] : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
// Inner loops start at the thread id and step by the block dim.
%21 = "gpu.thread_id"() {dimension = "x"} : () -> index
%22 = "gpu.block_dim"() {dimension = "x"} : () -> index
%23 = "gpu.thread_id"() {dimension = "y"} : () -> index
%24 = "gpu.block_dim"() {dimension = "y"} : () -> index
scf.for %arg2 = %23 to %15 step %24 {
scf.for %arg3 = %21 to %19 step %22 {
store %cst, %20[%arg2, %arg3] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
}
}
}
}
// Loop nest 2: for each output tile, walk the 784-sized dimension in steps of 4
// and accumulate into %0; the row bound is again %1.
scf.for %arg0 = %8 to %1 step %9 {
scf.for %arg1 = %10 to %c256 step %11 {
scf.for %arg2 = %c0 to %c784 step %c4 {
// %15 = min(8, %1 - %arg0).
%12 = muli %arg0, %c-1 : index
%13 = addi %12, %1 : index
%14 = cmpi "slt", %c8, %13 : index
%15 = select %14, %c8, %13 : index
// %19 = min(4, 784 - %arg2).
%16 = muli %arg2, %c-1 : index
%17 = addi %16, %c784 : index
%18 = cmpi "slt", %c4, %17 : index
%19 = select %18, %c4, %17 : index
%20 = subview %3[%arg0, %arg2] [%15, %19] [1, 1] : memref<?x784xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 784 + s0 + d1)>>
// %24 = min(8, 256 - %arg1).
%21 = muli %arg1, %c-1 : index
%22 = addi %21, %c256 : index
%23 = cmpi "slt", %c8, %22 : index
%24 = select %23, %c8, %22 : index
%25 = subview %2[%arg2, %arg1] [%19, %24] [1, 1] : memref<784x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
// %28 = min(8, %1 - %arg0); computed from the same operands as %13..%15 above.
%26 = addi %12, %1 : index
%27 = cmpi "slt", %c8, %26 : index
%28 = select %27, %c8, %26 : index
%29 = subview %0[%arg0, %arg1] [%28, %24] [1, 1] : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
%30 = "gpu.thread_id"() {dimension = "x"} : () -> index
%31 = "gpu.block_dim"() {dimension = "x"} : () -> index
%32 = "gpu.thread_id"() {dimension = "y"} : () -> index
%33 = "gpu.block_dim"() {dimension = "y"} : () -> index
// %29[%arg3, %arg4] += %20[%arg3, %arg5] * %25[%arg5, %arg4]
scf.for %arg3 = %32 to %15 step %33 {
scf.for %arg4 = %30 to %24 step %31 {
scf.for %arg5 = %c0 to %19 step %c1 {
%34 = load %25[%arg5, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
%35 = load %20[%arg3, %arg5] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 784 + s0 + d1)>>
%36 = mulf %35, %34 : f32
%37 = load %29[%arg3, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
%38 = addf %37, %36 : f32
store %38, %29[%arg3, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
}
}
}
}
}
}
return
}

Here is an example of a broadcast and add, which for static shapes would
be expected to fuse together. However, the intervening tie_shape ops prevent
such patterns from applying. Note that this came from an incomplete attempt made
prior to hacking the above examples to isolate everything at the top level.
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
// Broadcast-then-add dispatch on tensors: the first linalg.generic broadcasts
// the 256-element input, the second adds it elementwise to the ?x256 input.
// A shapex.tie_shape (%9) sits between the two generics.
func @predict_ex_dispatch_1() {
%c0 = constant 0 : index
// Two ranked shapes of the same type, built from push constants 1 and 0.
%0 = hal.interface.load.constant offset = 1 : index
%1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,256]>
%2 = hal.interface.load.constant offset = 0 : index
%3 = shapex.make_ranked_shape %2 : (index) -> !shapex.ranked_shape<[?,256]>
%4 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<256xf32>
%5 = hal.interface.load.tensor @legacy_io::@arg2, offset = %c0 : tensor<?x256xf32>
// %5 is tied twice: first to shape %1, then that result to shape %3.
%6 = shapex.tie_shape %5, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
%7 = shapex.tie_shape %6, %3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Broadcast: the input is indexed by d1 only, the output by (d0, d1).
%8 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} %4 {
^bb0(%arg0: f32): // no predecessors
linalg.yield %arg0 : f32
}: tensor<256xf32> -> tensor<?x256xf32>
// The broadcast result is tied to %3 before it reaches the add below.
%9 = shapex.tie_shape %8, %3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
// Elementwise add of %7 and %9 (identity indexing maps on all operands).
%10 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} %7, %9 {
^bb0(%arg0: f32, %arg1: f32): // no predecessors
%12 = addf %arg0, %arg1 : f32
linalg.yield %12 : f32
}: tensor<?x256xf32>, tensor<?x256xf32> -> tensor<?x256xf32>
%11 = shapex.tie_shape %10, %3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
hal.interface.store.tensor %11, @legacy_io::@ret0, offset = %c0 : tensor<?x256xf32>
return
}
// Interface declaring the bindings referenced as @legacy_io::@arg1, @arg2 and
// @ret0: two read-only storage buffers, one write/discard storage buffer, and
// two push constants.
hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
hal.interface.binding @arg1, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
}

This is more of an "I've spent a lot of time on this" observation than a discrete example. Since most lowerings need the shape of a result in order to generate the ops that feed into the input, it is easy to end up in situations where the IR is not ordered legally.
If a pass needs to be "shape aware", it must be taught how to walk around the
tie_shape ops. In this example, I've called them "identity metadata" ops and
taught the pass to ensure that they are properly hoisted at the boundaries of
dispatch regions (with bugs, I might add: see the canonicalize-cse-canonicalize
cycle above, which I am presently using to work around them).
I feel like what I actually wish I were reaching for here would be an
analysis that held all shape associations that I could query/update. However,
because this involves actually reifying the shapes, there is no place to
"terminate" the SSA values representing shape components, making such an
analysis hard to conceive of. I wish there were a less intrusive way to
"terminate" these reified shapes for a period of time and then drop them,
rather than needing to teach multiple layers of machinery about tie_shape or
suffering the phase-ordering issues of only being able to reason about
shapes once anchored on allocs (and presumed available in a runtime struct).