@stellaraccident
Created June 2, 2020 00:46

Issues with tie_shape

Multi-layer perceptron with dynamic batch size IR

func @predict(%arg0: tensor<?x784xf32>) -> tensor<?x10xf32> attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<?x784>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
    %0 = flow.variable.address @__iree_flow___sm_node1__h1_weights : !iree.ptr<tensor<784x256xf32>>
    %1 = flow.variable.address @__iree_flow___sm_node4__h1_bias : !iree.ptr<tensor<256xf32>>
    %2 = flow.variable.address @__iree_flow___sm_node2__h2_weights : !iree.ptr<tensor<256x256xf32>>
    %3 = flow.variable.address @__iree_flow___sm_node5__h2_bias : !iree.ptr<tensor<256xf32>>
    %4 = flow.variable.address @__iree_flow___sm_node3__out_weights : !iree.ptr<tensor<256x10xf32>>
    %5 = flow.variable.address @__iree_flow___sm_node6__out_bias : !iree.ptr<tensor<10xf32>>
    %rs256 = shapex.const_ranked_shape : !shapex.ranked_shape<[256]>
    %rs10 = shapex.const_ranked_shape : !shapex.ranked_shape<[10]>
    %6 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
    %7 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
    %8 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
    %9 = flow.variable.load.indirect %3 : !iree.ptr<tensor<256xf32>> -> tensor<256xf32>
    %10 = flow.variable.load.indirect %5 : !iree.ptr<tensor<10xf32>> -> tensor<10xf32>
    %11 = flow.variable.load.indirect %1 : !iree.ptr<tensor<256xf32>> -> tensor<256xf32>
    %12 = flow.variable.load.indirect %2 : !iree.ptr<tensor<256x256xf32>> -> tensor<256x256xf32>
    %13 = flow.variable.load.indirect %4 : !iree.ptr<tensor<256x10xf32>> -> tensor<256x10xf32>
    %14 = flow.variable.load.indirect %0 : !iree.ptr<tensor<784x256xf32>> -> tensor<784x256xf32>
    %15 = "xla_hlo.dot"(%arg0, %14) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
    %16 = shapex.get_ranked_shape %15 : tensor<?x256xf32> -> !shapex.ranked_shape<[?,256]>
    %17 = "shapex.ranked_broadcast_shape"(%rs256, %16) {lhs_broadcast_dimensions = dense<1> : tensor<1xi64>, rhs_broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (!shapex.ranked_shape<[256]>, !shapex.ranked_shape<[?,256]>) -> !shapex.ranked_shape<[?,256]>
    %18 = "shapex.ranked_broadcast_in_dim"(%15, %17) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %19 = "shapex.ranked_broadcast_in_dim"(%11, %17) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %20 = xla_hlo.add %18, %19 : tensor<?x256xf32>
    %21 = shapex.get_ranked_shape %20 : tensor<?x256xf32> -> !shapex.ranked_shape<[?,256]>
    %22 = "shapex.ranked_broadcast_in_dim"(%6, %21) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %23 = xla_hlo.multiply %20, %22 : tensor<?x256xf32>
    %24 = "xla_hlo.tanh"(%23) : (tensor<?x256xf32>) -> tensor<?x256xf32>
    %25 = xla_hlo.multiply %24, %22 : tensor<?x256xf32>
    %26 = xla_hlo.add %25, %22 : tensor<?x256xf32>
    %27 = "xla_hlo.dot"(%26, %12) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
    %28 = shapex.get_ranked_shape %27 : tensor<?x256xf32> -> !shapex.ranked_shape<[?,256]>
    %29 = "shapex.ranked_broadcast_shape"(%rs256, %28) {lhs_broadcast_dimensions = dense<1> : tensor<1xi64>, rhs_broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (!shapex.ranked_shape<[256]>, !shapex.ranked_shape<[?,256]>) -> !shapex.ranked_shape<[?,256]>
    %30 = "shapex.ranked_broadcast_in_dim"(%27, %29) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %31 = "shapex.ranked_broadcast_in_dim"(%9, %29) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %32 = xla_hlo.add %30, %31 : tensor<?x256xf32>
    %33 = shapex.get_ranked_shape %32 : tensor<?x256xf32> -> !shapex.ranked_shape<[?,256]>
    %34 = "shapex.ranked_broadcast_in_dim"(%6, %33) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %35 = xla_hlo.multiply %32, %34 : tensor<?x256xf32>
    %36 = "xla_hlo.tanh"(%35) : (tensor<?x256xf32>) -> tensor<?x256xf32>
    %37 = xla_hlo.multiply %36, %34 : tensor<?x256xf32>
    %38 = xla_hlo.add %37, %34 : tensor<?x256xf32>
    %39 = "xla_hlo.dot"(%38, %13) : (tensor<?x256xf32>, tensor<256x10xf32>) -> tensor<?x10xf32>
    %40 = shapex.get_ranked_shape %39 : tensor<?x10xf32> -> !shapex.ranked_shape<[?,10]>
    %41 = "shapex.ranked_broadcast_shape"(%rs10, %40) {lhs_broadcast_dimensions = dense<1> : tensor<1xi64>, rhs_broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (!shapex.ranked_shape<[10]>, !shapex.ranked_shape<[?,10]>) -> !shapex.ranked_shape<[?,10]>
    %42 = "shapex.ranked_broadcast_in_dim"(%39, %41) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %43 = "shapex.ranked_broadcast_in_dim"(%10, %41) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %44 = xla_hlo.add %42, %43 : tensor<?x10xf32>
    %45 = shapex.get_ranked_shape %44 : tensor<?x10xf32> -> !shapex.ranked_shape<[?,10]>
    %46 = "shapex.ranked_broadcast_in_dim"(%6, %45) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %47 = xla_hlo.multiply %44, %46 : tensor<?x10xf32>
    %48 = "xla_hlo.tanh"(%47) : (tensor<?x10xf32>) -> tensor<?x10xf32>
    %49 = xla_hlo.multiply %48, %46 : tensor<?x10xf32>
    %50 = xla_hlo.add %49, %46 : tensor<?x10xf32>
    %51 = "xla_hlo.reduce"(%50, %7) ( {
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
      %60 = xla_hlo.maximum %arg1, %arg2 : tensor<f32>
      "xla_hlo.return"(%60) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
    %52 = shapex.get_ranked_shape %50 : tensor<?x10xf32> -> !shapex.ranked_shape<[?,10]>
    %53 = "shapex.ranked_broadcast_in_dim"(%51, %52) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %54 = xla_hlo.subtract %50, %53 : tensor<?x10xf32>
    %55 = "xla_hlo.exponential"(%54) : (tensor<?x10xf32>) -> tensor<?x10xf32>
    %56 = "xla_hlo.reduce"(%55, %8) ( {
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
      %60 = xla_hlo.add %arg1, %arg2 : tensor<f32>
      "xla_hlo.return"(%60) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
    %57 = shapex.get_ranked_shape %50 : tensor<?x10xf32> -> !shapex.ranked_shape<[?,10]>
    %58 = "shapex.ranked_broadcast_in_dim"(%56, %57) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %59 = xla_hlo.divide %55, %58 : tensor<?x10xf32>
    return %59 : tensor<?x10xf32>
  }
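
For orientation, here is a rough plain-Python reading of the arithmetic in the IR above. It is a sketch only: the function and parameter names are made up for illustration, and nested lists stand in for tensors. Each layer is a dot, a bias add, and the multiply/tanh/multiply/add sequence with the 0.5 constant (which works out to a logistic sigmoid). The tail is a softmax built from the two reductions (maximum with init 0xFF800000, i.e. -inf, then add with init 0.0).

```python
import math

def sigmoid_via_tanh(x):
    # multiply by 0.5, tanh, multiply by 0.5, add 0.5 (same op order as the IR)
    return 0.5 * math.tanh(0.5 * x) + 0.5

def dense(rows, weights, bias):
    # rows: [batch][n_in], weights: [n_in][n_out], bias: [n_out]
    # The bias is broadcast along the (dynamic) batch dimension.
    return [
        [
            sigmoid_via_tanh(
                sum(r[k] * weights[k][j] for k in range(len(r))) + bias[j]
            )
            for j in range(len(bias))
        ]
        for r in rows
    ]

def softmax(rows):
    out = []
    for r in rows:
        m = max(r)                        # reduce(maximum) over dimension 1
        e = [math.exp(v - m) for v in r]  # subtract, exponential
        s = sum(e)                        # reduce(add) over dimension 1
        out.append([v / s for v in e])    # divide
    return out

def predict(x, w1, b1, w2, b2, w3, b3):
    # The batch size is simply len(x); nothing else depends on it.
    h1 = dense(x, w1, b1)
    h2 = dense(h1, w2, b2)
    return softmax(dense(h2, w3, b3))
```

The only dynamic quantity is the batch dimension, which is why every `!shapex.ranked_shape` in the dumps below has a single `?` in position 0.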

Hack the pass pipeline to isolate dynamic shaped dispatch regions

Compiles to a vulkan-spirv VM module.

iree-opt -mlir-disable-threading -iree-hal-target-backends=vulkan-spirv -pass-pipeline='canonicalize, func(xla-legalize-control-flow, iree-flow-hlo-to-hlo-preprocessing), convert-shape-to-shapex, iree-flow-flatten-tuples-in-cfg, inline{disable-simplify=false max-iterations=4}, func(canonicalize, cse), iree-flow-legalize-input-types, func(iree-flow-materialize-exported-reflection, iree-shape-expand-function-dynamic-dims, iree-flow-merge-exported-reflection, iree-shape-tie-dynamic, iree-shape-materialize-calculations, iree-shape-hoist-shape-calculations, iree-flow-pre-partitioning-conversion), iree-flow-dispatchability-analysis, func(iree-flow-identify-dispatch-regions2, canonicalize, cse, canonicalize, iree-flow-rematerialize-dispatch-constants), iree-flow-outline-dispatch-regions, func(canonicalize, iree-flow-post-partitioning-conversion, canonicalize, cse, iree-flow-hoist-unstreamable-ops, canonicalize, iree-flow-form-streams, canonicalize, cse), symbol-dce, canonicalize, iree-hal-materialize-interfaces, hal.executable(iree-hal-translate-executables), iree-hal-link-executables, iree-convert-flow-to-hal, func(iree-shape-expand-function-ranked-shape-dims, canonicalize, cse), iree-hal-public-abi-generation, iree-hal-materialize-resource-caches, func(iree-hal-inline-device-switches), iree-hal-memoize-device-queries, func(canonicalize, cse), hal.executable(iree-hal-serialize-executables), symbol-dce, canonicalize, iree-vm-conversion, vm.module(iree-vm-global-initialization), inline{disable-simplify=false max-iterations=4}, cse, symbol-dce' --mlir-elide-elementsattrs-if-larger=100 ../data/mlp_reproducer.mlir

After MaterializeShapeCalculations pass

// *** IR Dump After mlir::iree_compiler::Shape::`anonymous-namespace'::MaterializeShapeCalculationsPass ***
func @predict(%arg0: tensor<?x784xf32> {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?,784]> {iree.reflection = {}}) -> (tensor<?x10xf32> {iree.reflection = {}}, !shapex.ranked_shape<[?,10]> {iree.reflection = {}}) attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, f = "I11!B8!d-1d784R10!B7!d-1d10", fv = "1", sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<?x784>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
  %0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
  %1 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
  %2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
  %cst = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
  %cst_0 = constant dense<[-1.1745497, -6.830580e-01, -0.22088325, 1.9310919, -0.511352897, -0.557739139, -0.712474167, 0.261824518, 0.724062442, 1.56302774]> : tensor<10xf32>
  %cst_1 = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
  %cst_2 = constant opaque<"", "0xDEADBEEF"> : tensor<256x256xf32>
  %cst_3 = constant opaque<"", "0xDEADBEEF"> : tensor<256x10xf32>
  %cst_4 = constant opaque<"", "0xDEADBEEF"> : tensor<784x256xf32>
  %c1 = constant 1 : index
  %3 = shapex.tie_shape %arg0, %arg1 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
  %4 = "xla_hlo.dot"(%3, %cst_4) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
  %5 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,784]> -> index
  %6 = shapex.make_ranked_shape %5 : (index) -> !shapex.ranked_shape<[?,256]>
  %7 = shapex.tie_shape %4, %6 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %8 = cmpi "ugt", %c1, %5 : index
  %9 = select %8, %c1, %5 : index
  %10 = shapex.make_ranked_shape %9 : (index) -> !shapex.ranked_shape<[?,256]>
  %11 = "shapex.ranked_broadcast_in_dim"(%7, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %12 = shapex.tie_shape %11, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %13 = "shapex.ranked_broadcast_in_dim"(%cst_1, %10) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %14 = shapex.tie_shape %13, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %15 = xla_hlo.add %12, %14 : tensor<?x256xf32>
  %16 = shapex.tie_shape %15, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %17 = "shapex.ranked_broadcast_in_dim"(%0, %10) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %18 = shapex.tie_shape %17, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %19 = xla_hlo.multiply %16, %18 : tensor<?x256xf32>
  %20 = shapex.tie_shape %19, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %21 = "xla_hlo.tanh"(%20) : (tensor<?x256xf32>) -> tensor<?x256xf32>
  %22 = shapex.tie_shape %21, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %23 = xla_hlo.multiply %22, %18 : tensor<?x256xf32>
  %24 = shapex.tie_shape %23, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %25 = xla_hlo.add %24, %18 : tensor<?x256xf32>
  %26 = shapex.tie_shape %25, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %27 = "xla_hlo.dot"(%26, %cst_2) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
  %28 = shapex.tie_shape %27, %10 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %29 = cmpi "ugt", %c1, %9 : index
  %30 = select %29, %c1, %9 : index
  %31 = shapex.make_ranked_shape %30 : (index) -> !shapex.ranked_shape<[?,256]>
  %32 = "shapex.ranked_broadcast_in_dim"(%28, %31) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %33 = shapex.tie_shape %32, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %34 = "shapex.ranked_broadcast_in_dim"(%cst, %31) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %35 = shapex.tie_shape %34, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %36 = xla_hlo.add %33, %35 : tensor<?x256xf32>
  %37 = shapex.tie_shape %36, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %38 = "shapex.ranked_broadcast_in_dim"(%0, %31) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %39 = shapex.tie_shape %38, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %40 = xla_hlo.multiply %37, %39 : tensor<?x256xf32>
  %41 = shapex.tie_shape %40, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %42 = "xla_hlo.tanh"(%41) : (tensor<?x256xf32>) -> tensor<?x256xf32>
  %43 = shapex.tie_shape %42, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %44 = xla_hlo.multiply %43, %39 : tensor<?x256xf32>
  %45 = shapex.tie_shape %44, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %46 = xla_hlo.add %45, %39 : tensor<?x256xf32>
  %47 = shapex.tie_shape %46, %31 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %48 = "xla_hlo.dot"(%47, %cst_3) : (tensor<?x256xf32>, tensor<256x10xf32>) -> tensor<?x10xf32>
  %49 = shapex.make_ranked_shape %30 : (index) -> !shapex.ranked_shape<[?,10]>
  %50 = shapex.tie_shape %48, %49 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %51 = cmpi "ugt", %c1, %30 : index
  %52 = select %51, %c1, %30 : index
  %53 = shapex.make_ranked_shape %52 : (index) -> !shapex.ranked_shape<[?,10]>
  %54 = "shapex.ranked_broadcast_in_dim"(%50, %53) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %55 = shapex.tie_shape %54, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %56 = "shapex.ranked_broadcast_in_dim"(%cst_0, %53) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %57 = shapex.tie_shape %56, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %58 = xla_hlo.add %55, %57 : tensor<?x10xf32>
  %59 = shapex.tie_shape %58, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %60 = "shapex.ranked_broadcast_in_dim"(%0, %53) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %61 = shapex.tie_shape %60, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %62 = xla_hlo.multiply %59, %61 : tensor<?x10xf32>
  %63 = shapex.tie_shape %62, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %64 = "xla_hlo.tanh"(%63) : (tensor<?x10xf32>) -> tensor<?x10xf32>
  %65 = shapex.tie_shape %64, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %66 = xla_hlo.multiply %65, %61 : tensor<?x10xf32>
  %67 = shapex.tie_shape %66, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %68 = xla_hlo.add %67, %61 : tensor<?x10xf32>
  %69 = shapex.tie_shape %68, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %70 = "xla_hlo.reduce"(%69, %1) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
    %86 = xla_hlo.maximum %arg2, %arg3 : tensor<f32>
    "xla_hlo.return"(%86) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
  %71 = shapex.make_ranked_shape %52 : (index) -> !shapex.ranked_shape<[?]>
  %72 = shapex.tie_shape %70, %71 : tensor<?xf32>, !shapex.ranked_shape<[?]>
  %73 = "shapex.ranked_broadcast_in_dim"(%72, %53) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %74 = shapex.tie_shape %73, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %75 = xla_hlo.subtract %69, %74 : tensor<?x10xf32>
  %76 = shapex.tie_shape %75, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %77 = "xla_hlo.exponential"(%76) : (tensor<?x10xf32>) -> tensor<?x10xf32>
  %78 = shapex.tie_shape %77, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %79 = "xla_hlo.reduce"(%78, %2) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
    %86 = xla_hlo.add %arg2, %arg3 : tensor<f32>
    "xla_hlo.return"(%86) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
  %80 = shapex.make_ranked_shape %52 : (index) -> !shapex.ranked_shape<[?]>
  %81 = shapex.tie_shape %79, %80 : tensor<?xf32>, !shapex.ranked_shape<[?]>
  %82 = "shapex.ranked_broadcast_in_dim"(%81, %53) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %83 = shapex.tie_shape %82, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %84 = xla_hlo.divide %78, %83 : tensor<?x10xf32>
  %85 = shapex.tie_shape %84, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  return %85, %53 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
}
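
The repeated `cmpi "ugt"` / `select` pairs in this dump are the materialized broadcast-shape calculations for the batch dimension. As a small sketch of the arithmetic (function names are illustrative, and non-negative Python ints stand in for unsigned `index` values), each pair picks the larger of 1 and the incoming dimension, and each layer feeds its result into the next:

```python
def select_ugt(a, b):
    # Mirrors: %c = cmpi "ugt", a, b  followed by  select %c, a, b
    return a if a > b else b

def materialized_batch_dims(d):
    # d plays the role of: shapex.ranked_dim %arg1[0]
    d1 = select_ugt(1, d)   # layer 1 bias broadcast
    d2 = select_ugt(1, d1)  # layer 2 bias broadcast
    d3 = select_ugt(1, d2)  # output layer bias broadcast
    return d1, d2, d3

for d in (0, 1, 4):
    print(d, materialized_batch_dims(d))
```

Arithmetically the three values are always equal to each other (taking the max with 1 twice changes nothing), and they equal `d` itself whenever `d >= 1`. The IR nevertheless carries them as distinct SSA values, each with its own `make_ranked_shape`.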

Then hoist

// *** IR Dump After mlir::iree_compiler::Shape::`anonymous-namespace'::HoistShapeCalculations ***
func @predict(%arg0: tensor<?x784xf32> {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?,784]> {iree.reflection = {}}) -> (tensor<?x10xf32> {iree.reflection = {}}, !shapex.ranked_shape<[?,10]> {iree.reflection = {}}) attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, f = "I11!B8!d-1d784R10!B7!d-1d10", fv = "1", sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<?x784>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
  %0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,784]> -> index
  %1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,256]>
  %c1 = constant 1 : index
  %2 = cmpi "ugt", %c1, %0 : index
  %3 = select %2, %c1, %0 : index
  %4 = cmpi "ugt", %c1, %3 : index
  %5 = select %4, %c1, %3 : index
  %6 = cmpi "ugt", %c1, %5 : index
  %7 = select %6, %c1, %5 : index
  %8 = shapex.make_ranked_shape %7 : (index) -> !shapex.ranked_shape<[?]>
  %9 = shapex.make_ranked_shape %7 : (index) -> !shapex.ranked_shape<[?]>
  %10 = shapex.make_ranked_shape %7 : (index) -> !shapex.ranked_shape<[?,10]>
  %11 = shapex.make_ranked_shape %5 : (index) -> !shapex.ranked_shape<[?,10]>
  %12 = shapex.make_ranked_shape %5 : (index) -> !shapex.ranked_shape<[?,256]>
  %13 = shapex.make_ranked_shape %3 : (index) -> !shapex.ranked_shape<[?,256]>
  %14 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
  %15 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
  %16 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
  %cst = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
  %cst_0 = constant dense<[-1.1745497, -6.830580e-01, -0.22088325, 1.9310919, -0.511352897, -0.557739139, -0.712474167, 0.261824518, 0.724062442, 1.56302774]> : tensor<10xf32>
  %cst_1 = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
  %cst_2 = constant opaque<"", "0xDEADBEEF"> : tensor<256x256xf32>
  %cst_3 = constant opaque<"", "0xDEADBEEF"> : tensor<256x10xf32>
  %cst_4 = constant opaque<"", "0xDEADBEEF"> : tensor<784x256xf32>
  %17 = shapex.tie_shape %arg0, %arg1 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
  %18 = "xla_hlo.dot"(%17, %cst_4) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
  %19 = shapex.tie_shape %18, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %20 = "shapex.ranked_broadcast_in_dim"(%19, %13) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %21 = shapex.tie_shape %20, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %22 = "shapex.ranked_broadcast_in_dim"(%cst_1, %13) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %23 = shapex.tie_shape %22, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %24 = xla_hlo.add %21, %23 : tensor<?x256xf32>
  %25 = shapex.tie_shape %24, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %26 = "shapex.ranked_broadcast_in_dim"(%14, %13) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %27 = shapex.tie_shape %26, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %28 = xla_hlo.multiply %25, %27 : tensor<?x256xf32>
  %29 = shapex.tie_shape %28, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %30 = "xla_hlo.tanh"(%29) : (tensor<?x256xf32>) -> tensor<?x256xf32>
  %31 = shapex.tie_shape %30, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %32 = xla_hlo.multiply %31, %27 : tensor<?x256xf32>
  %33 = shapex.tie_shape %32, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %34 = xla_hlo.add %33, %27 : tensor<?x256xf32>
  %35 = shapex.tie_shape %34, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %36 = "xla_hlo.dot"(%35, %cst_2) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
  %37 = shapex.tie_shape %36, %13 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %38 = "shapex.ranked_broadcast_in_dim"(%37, %12) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %39 = shapex.tie_shape %38, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %40 = "shapex.ranked_broadcast_in_dim"(%cst, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %41 = shapex.tie_shape %40, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %42 = xla_hlo.add %39, %41 : tensor<?x256xf32>
  %43 = shapex.tie_shape %42, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %44 = "shapex.ranked_broadcast_in_dim"(%14, %12) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
  %45 = shapex.tie_shape %44, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %46 = xla_hlo.multiply %43, %45 : tensor<?x256xf32>
  %47 = shapex.tie_shape %46, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %48 = "xla_hlo.tanh"(%47) : (tensor<?x256xf32>) -> tensor<?x256xf32>
  %49 = shapex.tie_shape %48, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %50 = xla_hlo.multiply %49, %45 : tensor<?x256xf32>
  %51 = shapex.tie_shape %50, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %52 = xla_hlo.add %51, %45 : tensor<?x256xf32>
  %53 = shapex.tie_shape %52, %12 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %54 = "xla_hlo.dot"(%53, %cst_3) : (tensor<?x256xf32>, tensor<256x10xf32>) -> tensor<?x10xf32>
  %55 = shapex.tie_shape %54, %11 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %56 = "shapex.ranked_broadcast_in_dim"(%55, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %57 = shapex.tie_shape %56, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %58 = "shapex.ranked_broadcast_in_dim"(%cst_0, %10) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %59 = shapex.tie_shape %58, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %60 = xla_hlo.add %57, %59 : tensor<?x10xf32>
  %61 = shapex.tie_shape %60, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %62 = "shapex.ranked_broadcast_in_dim"(%14, %10) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %63 = shapex.tie_shape %62, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %64 = xla_hlo.multiply %61, %63 : tensor<?x10xf32>
  %65 = shapex.tie_shape %64, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %66 = "xla_hlo.tanh"(%65) : (tensor<?x10xf32>) -> tensor<?x10xf32>
  %67 = shapex.tie_shape %66, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %68 = xla_hlo.multiply %67, %63 : tensor<?x10xf32>
  %69 = shapex.tie_shape %68, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %70 = xla_hlo.add %69, %63 : tensor<?x10xf32>
  %71 = shapex.tie_shape %70, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %72 = "xla_hlo.reduce"(%71, %15) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
    %86 = xla_hlo.maximum %arg2, %arg3 : tensor<f32>
    "xla_hlo.return"(%86) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
  %73 = shapex.tie_shape %72, %9 : tensor<?xf32>, !shapex.ranked_shape<[?]>
  %74 = "shapex.ranked_broadcast_in_dim"(%73, %10) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %75 = shapex.tie_shape %74, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %76 = xla_hlo.subtract %71, %75 : tensor<?x10xf32>
  %77 = shapex.tie_shape %76, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %78 = "xla_hlo.exponential"(%77) : (tensor<?x10xf32>) -> tensor<?x10xf32>
  %79 = shapex.tie_shape %78, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %80 = "xla_hlo.reduce"(%79, %16) ( {
  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
    %86 = xla_hlo.add %arg2, %arg3 : tensor<f32>
    "xla_hlo.return"(%86) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
  %81 = shapex.tie_shape %80, %8 : tensor<?xf32>, !shapex.ranked_shape<[?]>
  %82 = "shapex.ranked_broadcast_in_dim"(%81, %10) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
  %83 = shapex.tie_shape %82, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %84 = xla_hlo.divide %79, %83 : tensor<?x10xf32>
  %85 = shapex.tie_shape %84, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  return %85, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
}
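
As a mental model only (this is not the pass's actual algorithm, and the op records and names below are invented for illustration), the hoist step can be pictured as a stable partition: any shape-calculation op whose operands are all block arguments or already-hoisted values moves to the top of the block, and everything else keeps its relative order.

```python
# Toy op record: (result_name, op_name, operand_names). Illustrative only.
SHAPE_OPS = {
    "shapex.ranked_dim",
    "shapex.make_ranked_shape",
    "constant.index",  # stands in for index constants such as %c1
    "cmpi",
    "select",
}

def hoist_shape_ops(block_args, ops):
    """Move hoistable shape calculations to the front, preserving order."""
    available = set(block_args)  # values defined before any op runs
    hoisted, rest = [], []
    for result, name, operands in ops:
        if name in SHAPE_OPS and all(o in available for o in operands):
            hoisted.append((result, name, operands))
            available.add(result)  # later shape ops may depend on this one
        else:
            rest.append((result, name, operands))
    return hoisted + rest

ops = [
    ("t0", "shapex.tie_shape", ["arg0", "arg1"]),
    ("d", "shapex.ranked_dim", ["arg1"]),
    ("dot", "xla_hlo.dot", ["t0"]),
    ("s", "shapex.make_ranked_shape", ["d"]),
    ("t1", "shapex.tie_shape", ["dot", "s"]),
]
print([r for r, _, _ in hoist_shape_ops(["arg0", "arg1"], ops)])
```

The real pass evidently orders some ops differently (compare where the `make_ranked_shape` ops land in the two dumps), so treat this as an approximation of the effect rather than a description of the implementation.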

Identify dispatch regions

This run includes a patch that disables fusion of anything with dynamic shapes, which works around fusion issues on the backend (see later).

Note that this is after a hacky canonicalize-cse-canonicalize sequence, because it is hard to do the tie_shape gymnastics in one go.

The shape-related phase-ordering issue here is that forming dispatch regions around anchor ops ("hero" ops in XLA parlance) requires reasoning about the "workload", which in turn requires reasoning about the shape in terms of calculations on actual dimension values. Having access to these as materialized SSA values seems important at this phase.
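
To make the "workload" dependency concrete: in the dump that follows, each `flow.dispatch.region[...]` workload operand is a `muli` of a batch-dimension value and a static dimension (256 or 10), which matches the element count of the region's result. A minimal sketch of that calculation, with hypothetical helper names and `None` standing in for `?`:

```python
from math import prod  # available in Python 3.8+

def make_ranked_shape(pattern, dynamic_dims):
    """Fill the None (dynamic, '?') slots of pattern with runtime values, in order."""
    values = iter(dynamic_dims)
    return [next(values) if d is None else d for d in pattern]

def workload(shape):
    """Element count of a fully resolved shape."""
    return prod(shape)

batch = 7  # hypothetical runtime batch size
hidden = make_ranked_shape([None, 256], [batch])
logits = make_ranked_shape([None, 10], [batch])
print(workload(hidden), workload(logits))  # prints "1792 70"
```

The point of the sketch is that `workload` cannot be evaluated until the dynamic dimension is an actual value, which is why the dimension needs to exist as an SSA value before dispatch regions are formed.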

// *** IR Dump After Canonicalizer ***
func @predict(%arg0: tensor<?x784xf32> {iree.reflection = {}}, %arg1: !shapex.ranked_shape<[?,784]> {iree.reflection = {}}) -> (tensor<?x10xf32> {iree.reflection = {}}, !shapex.ranked_shape<[?,10]> {iree.reflection = {}}) attributes {iree.module.export, iree.reflection = {abi = "sip", abiv = 1 : i32, f = "I11!B8!d-1d784R10!B7!d-1d10", fv = "1", sip = "I8!S5!k0_0R3!_0"}, tf._input_shapes = [#tf.shape<?x784>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {
  %c1 = constant 1 : index
  %c10 = constant 10 : index
  %c256 = constant 256 : index
  %cst = constant dense<5.000000e-01> : tensor<f32>
  %cst_0 = constant dense<0xFF800000> : tensor<f32>
  %cst_1 = constant dense<0.000000e+00> : tensor<f32>
  %cst_2 = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
  %cst_3 = constant dense<[-1.1745497, -6.830580e-01, -0.22088325, 1.9310919, -0.511352897, -0.557739139, -0.712474167, 0.261824518, 0.724062442, 1.56302774]> : tensor<10xf32>
  %cst_4 = constant opaque<"", "0xDEADBEEF"> : tensor<256xf32>
  %cst_5 = constant opaque<"", "0xDEADBEEF"> : tensor<256x256xf32>
  %cst_6 = constant opaque<"", "0xDEADBEEF"> : tensor<256x10xf32>
  %cst_7 = constant opaque<"", "0xDEADBEEF"> : tensor<784x256xf32>
  %0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,784]> -> index
  %1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,256]>
  %2 = muli %0, %c256 : index
  %3 = cmpi "ugt", %c1, %0 : index
  %4 = select %3, %c1, %0 : index
  %5 = cmpi "ugt", %c1, %4 : index
  %6 = select %5, %c1, %4 : index
  %7 = cmpi "ugt", %c1, %6 : index
  %8 = select %7, %c1, %6 : index
  %9 = shapex.make_ranked_shape %8 : (index) -> !shapex.ranked_shape<[?]>
  %10 = shapex.make_ranked_shape %8 : (index) -> !shapex.ranked_shape<[?,10]>
  %11 = muli %8, %c10 : index
  %12 = shapex.make_ranked_shape %6 : (index) -> !shapex.ranked_shape<[?,10]>
  %13 = muli %6, %c10 : index
  %14 = shapex.make_ranked_shape %6 : (index) -> !shapex.ranked_shape<[?,256]>
  %15 = muli %6, %c256 : index
  %16 = shapex.make_ranked_shape %4 : (index) -> !shapex.ranked_shape<[?,256]>
  %17 = muli %4, %c256 : index
  %18 = shapex.tie_shape %arg0, %arg1 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
  %19 = flow.dispatch.region[%2 : index](%arg2 = %cst_7 : tensor<784x256xf32>, %arg3 = %1 : !shapex.ranked_shape<[?,256]>, %arg4 = %18 : tensor<?x784xf32>, %arg5 = %arg1 : !shapex.ranked_shape<[?,784]>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg5 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
    %96 = "xla_hlo.dot"(%95, %arg2) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
    %97 = shapex.tie_shape %96, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %97 : tensor<?x256xf32>
  }
  %20 = shapex.tie_shape %19, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %21 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %20 : tensor<?x256xf32>, %arg4 = %1 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg3, %arg4 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %97 : tensor<?x256xf32>
  }
  %22 = shapex.tie_shape %21, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %23 = flow.dispatch.region[%17 : index](%arg2 = %cst_4 : tensor<256xf32>, %arg3 = %16 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
    %95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %96 = shapex.tie_shape %95, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %96 : tensor<?x256xf32>
  }
  %24 = shapex.tie_shape %23, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %25 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %24 : tensor<?x256xf32>, %arg4 = %22 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %97 = xla_hlo.add %95, %96 : tensor<?x256xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %98 : tensor<?x256xf32>
  }
  %26 = flow.dispatch.region[%17 : index](%arg2 = %cst : tensor<f32>, %arg3 = %16 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
    %95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %96 = shapex.tie_shape %95, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %96 : tensor<?x256xf32>
  }
  %27 = shapex.tie_shape %26, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %28 = shapex.tie_shape %26, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %29 = shapex.tie_shape %26, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %30 = shapex.tie_shape %25, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %31 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %29 : tensor<?x256xf32>, %arg4 = %30 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %97 = xla_hlo.multiply %95, %96 : tensor<?x256xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %98 : tensor<?x256xf32>
  }
  %32 = shapex.tie_shape %31, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %33 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %32 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = "xla_hlo.tanh"(%95) : (tensor<?x256xf32>) -> tensor<?x256xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %97 : tensor<?x256xf32>
  }
  %34 = shapex.tie_shape %33, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %35 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %34 : tensor<?x256xf32>, %arg4 = %28 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %97 = xla_hlo.multiply %96, %95 : tensor<?x256xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %98 : tensor<?x256xf32>
  }
  %36 = shapex.tie_shape %35, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %37 = flow.dispatch.region[%17 : index](%arg2 = %16 : !shapex.ranked_shape<[?,256]>, %arg3 = %36 : tensor<?x256xf32>, %arg4 = %27 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %97 = xla_hlo.add %96, %95 : tensor<?x256xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %98 : tensor<?x256xf32>
  }
  %38 = shapex.tie_shape %37, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %39 = flow.dispatch.region[%17 : index](%arg2 = %cst_5 : tensor<256x256xf32>, %arg3 = %16 : !shapex.ranked_shape<[?,256]>, %arg4 = %38 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = "xla_hlo.dot"(%95, %arg2) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
    %97 = shapex.tie_shape %96, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %97 : tensor<?x256xf32>
  }
  %40 = shapex.tie_shape %39, %16 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %41 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %40 : tensor<?x256xf32>, %arg4 = %16 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg3, %arg4 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %97 : tensor<?x256xf32>
  }
  %42 = shapex.tie_shape %41, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %43 = flow.dispatch.region[%15 : index](%arg2 = %cst_2 : tensor<256xf32>, %arg3 = %14 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
    %95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %96 = shapex.tie_shape %95, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %96 : tensor<?x256xf32>
  }
  %44 = shapex.tie_shape %43, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %45 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %44 : tensor<?x256xf32>, %arg4 = %42 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %97 = xla_hlo.add %95, %96 : tensor<?x256xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %98 : tensor<?x256xf32>
  }
  %46 = flow.dispatch.region[%15 : index](%arg2 = %cst : tensor<f32>, %arg3 = %14 : !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32> {
    %95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,256]>) -> tensor<?x256xf32>
    %96 = shapex.tie_shape %95, %arg3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %96 : tensor<?x256xf32>
  }
  %47 = shapex.tie_shape %46, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %48 = shapex.tie_shape %46, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %49 = shapex.tie_shape %46, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %50 = shapex.tie_shape %45, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %51 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %49 : tensor<?x256xf32>, %arg4 = %50 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %97 = xla_hlo.multiply %95, %96 : tensor<?x256xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %98 : tensor<?x256xf32>
  }
  %52 = shapex.tie_shape %51, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %53 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %52 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = "xla_hlo.tanh"(%95) : (tensor<?x256xf32>) -> tensor<?x256xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %97 : tensor<?x256xf32>
  }
  %54 = shapex.tie_shape %53, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %55 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %54 : tensor<?x256xf32>, %arg4 = %48 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %97 = xla_hlo.multiply %96, %95 : tensor<?x256xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %98 : tensor<?x256xf32>
  }
  %56 = shapex.tie_shape %55, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %57 = flow.dispatch.region[%15 : index](%arg2 = %14 : !shapex.ranked_shape<[?,256]>, %arg3 = %56 : tensor<?x256xf32>, %arg4 = %47 : tensor<?x256xf32>) -> tensor<?x256xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %97 = xla_hlo.add %96, %95 : tensor<?x256xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    flow.return %98 : tensor<?x256xf32>
  }
  %58 = shapex.tie_shape %57, %14 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %59 = flow.dispatch.region[%13 : index](%arg2 = %cst_6 : tensor<256x10xf32>, %arg3 = %12 : !shapex.ranked_shape<[?,10]>, %arg4 = %58 : tensor<?x256xf32>, %arg5 = %14 : !shapex.ranked_shape<[?,256]>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg4, %arg5 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %96 = "xla_hlo.dot"(%95, %arg2) : (tensor<?x256xf32>, tensor<256x10xf32>) -> tensor<?x10xf32>
    %97 = shapex.tie_shape %96, %arg3 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %97 : tensor<?x10xf32>
  }
  %60 = shapex.tie_shape %59, %12 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %61 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %60 : tensor<?x10xf32>, %arg4 = %12 : !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg3, %arg4 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %97 : tensor<?x10xf32>
  }
  %62 = shapex.tie_shape %61, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %63 = flow.dispatch.region[%11 : index](%arg2 = %cst_3 : tensor<10xf32>, %arg3 = %10 : !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32> {
    %95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %96 = shapex.tie_shape %95, %arg3 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %96 : tensor<?x10xf32>
  }
  %64 = shapex.tie_shape %63, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %65 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %64 : tensor<?x10xf32>, %arg4 = %62 : tensor<?x10xf32>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %97 = xla_hlo.add %95, %96 : tensor<?x10xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %98 : tensor<?x10xf32>
  }
  %66 = flow.dispatch.region[%11 : index](%arg2 = %cst : tensor<f32>, %arg3 = %10 : !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32> {
    %95 = "shapex.ranked_broadcast_in_dim"(%arg2, %arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %96 = shapex.tie_shape %95, %arg3 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %96 : tensor<?x10xf32>
  }
  %67 = shapex.tie_shape %66, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %68 = shapex.tie_shape %66, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %69 = shapex.tie_shape %66, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %70 = shapex.tie_shape %65, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %71 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %69 : tensor<?x10xf32>, %arg4 = %70 : tensor<?x10xf32>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %97 = xla_hlo.multiply %95, %96 : tensor<?x10xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %98 : tensor<?x10xf32>
  }
  %72 = shapex.tie_shape %71, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %73 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %72 : tensor<?x10xf32>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = "xla_hlo.tanh"(%95) : (tensor<?x10xf32>) -> tensor<?x10xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %97 : tensor<?x10xf32>
  }
  %74 = shapex.tie_shape %73, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %75 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %74 : tensor<?x10xf32>, %arg4 = %68 : tensor<?x10xf32>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %97 = xla_hlo.multiply %96, %95 : tensor<?x10xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %98 : tensor<?x10xf32>
  }
  %76 = shapex.tie_shape %75, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %77 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %76 : tensor<?x10xf32>, %arg4 = %67 : tensor<?x10xf32>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %97 = xla_hlo.add %96, %95 : tensor<?x10xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %98 : tensor<?x10xf32>
  }
  %78 = shapex.tie_shape %77, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %79 = flow.dispatch.region[%8 : index](%arg2 = %cst_0 : tensor<f32>, %arg3 = %9 : !shapex.ranked_shape<[?]>, %arg4 = %78 : tensor<?x10xf32>, %arg5 = %10 : !shapex.ranked_shape<[?,10]>) -> tensor<?xf32> {
    %95 = shapex.tie_shape %arg4, %arg5 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = "xla_hlo.reduce"(%95, %arg2) ( {
    ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):  // no predecessors
      %98 = xla_hlo.maximum %arg6, %arg7 : tensor<f32>
      "xla_hlo.return"(%98) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
    %97 = shapex.tie_shape %96, %arg3 : tensor<?xf32>, !shapex.ranked_shape<[?]>
    flow.return %97 : tensor<?xf32>
  }
  %80 = shapex.tie_shape %79, %9 : tensor<?xf32>, !shapex.ranked_shape<[?]>
  %81 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %80 : tensor<?xf32>, %arg4 = %9 : !shapex.ranked_shape<[?]>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg3, %arg4 : tensor<?xf32>, !shapex.ranked_shape<[?]>
    %96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %97 : tensor<?x10xf32>
  }
  %82 = shapex.tie_shape %81, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %83 = shapex.tie_shape %77, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %84 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %82 : tensor<?x10xf32>, %arg4 = %83 : tensor<?x10xf32>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %97 = xla_hlo.subtract %95, %96 : tensor<?x10xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %98 : tensor<?x10xf32>
  }
  %85 = shapex.tie_shape %84, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %86 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %85 : tensor<?x10xf32>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = "xla_hlo.exponential"(%95) : (tensor<?x10xf32>) -> tensor<?x10xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %97 : tensor<?x10xf32>
  }
  %87 = shapex.tie_shape %86, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %88 = flow.dispatch.region[%8 : index](%arg2 = %cst_1 : tensor<f32>, %arg3 = %9 : !shapex.ranked_shape<[?]>, %arg4 = %87 : tensor<?x10xf32>, %arg5 = %10 : !shapex.ranked_shape<[?,10]>) -> tensor<?xf32> {
    %95 = shapex.tie_shape %arg4, %arg5 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = "xla_hlo.reduce"(%95, %arg2) ( {
    ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):  // no predecessors
      %98 = xla_hlo.add %arg6, %arg7 : tensor<f32>
      "xla_hlo.return"(%98) : (tensor<f32>) -> ()
    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<?x10xf32>, tensor<f32>) -> tensor<?xf32>
    %97 = shapex.tie_shape %96, %arg3 : tensor<?xf32>, !shapex.ranked_shape<[?]>
    flow.return %97 : tensor<?xf32>
  }
  %89 = shapex.tie_shape %88, %9 : tensor<?xf32>, !shapex.ranked_shape<[?]>
  %90 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %89 : tensor<?xf32>, %arg4 = %9 : !shapex.ranked_shape<[?]>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg3, %arg4 : tensor<?xf32>, !shapex.ranked_shape<[?]>
    %96 = "shapex.ranked_broadcast_in_dim"(%95, %arg2) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, !shapex.ranked_shape<[?,10]>) -> tensor<?x10xf32>
    %97 = shapex.tie_shape %96, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %97 : tensor<?x10xf32>
  }
  %91 = shapex.tie_shape %90, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %92 = shapex.tie_shape %86, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  %93 = flow.dispatch.region[%11 : index](%arg2 = %10 : !shapex.ranked_shape<[?,10]>, %arg3 = %91 : tensor<?x10xf32>, %arg4 = %92 : tensor<?x10xf32>) -> tensor<?x10xf32> {
    %95 = shapex.tie_shape %arg4, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %96 = shapex.tie_shape %arg3, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    %97 = xla_hlo.divide %95, %96 : tensor<?x10xf32>
    %98 = shapex.tie_shape %97, %arg2 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
    flow.return %98 : tensor<?x10xf32>
  }
  %94 = shapex.tie_shape %93, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
  return %94, %10 : tensor<?x10xf32>, !shapex.ranked_shape<[?,10]>
}

In the next step, each dispatch region is outlined into its own executable so that it can be compiled separately. During this phase we also remove every tie_shape except the root ones on the inputs, because the representation now carries enough information to reconstruct all unknown dims from those roots.

  flow.executable @predict_ex_dispatch_9 attributes {sym_visibility = "private"} {
    flow.dispatch.entry @predict_ex_dispatch_9
    module {
      func @predict_ex_dispatch_9(%arg0: tensor<256x256xf32>, %arg1: index, %arg2: tensor<?x256xf32>) -> tensor<?x256xf32> {
        %0 = shapex.make_ranked_shape %arg1 : (index) -> !shapex.ranked_shape<[?,256]>
        %1 = shapex.tie_shape %arg2, %0 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
        %2 = "xla_hlo.dot"(%1, %arg0) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
        return %2 : tensor<?x256xf32>
      }
    }
  }
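
The claim that the root tie_shapes are sufficient can be illustrated with a small Python model. This is a minimal sketch, not IREE code: `make_ranked_shape` and `dot_result_shape` are hypothetical helpers that only mimic the ops with similar names, `None` stands in for `?`, and the batch value of 32 is an assumed runtime value for `%arg1`.

```python
def make_ranked_shape(pattern, dynamic_dims):
    """Model of shapex.make_ranked_shape (hypothetical helper, not an IREE API).

    Each None ('?') slot in `pattern` is filled, in order, with a runtime dim.
    """
    remaining = list(dynamic_dims)
    shape = tuple(remaining.pop(0) if d is None else d for d in pattern)
    if remaining:
        raise ValueError("more runtime dims than '?' slots")
    return shape


def dot_result_shape(lhs_shape, rhs_shape):
    """Shape rule for a rank-2 dot: (m, k) x (k, n) -> (m, n)."""
    m, k = lhs_shape
    k2, n = rhs_shape
    if k != k2:
        raise ValueError("contracting dims differ")
    return (m, n)


# The only tie_shape left in predict_ex_dispatch_9 is on the ?x256 input.
batch = 32  # assumed runtime value of %arg1
lhs_shape = make_ranked_shape((None, 256), [batch])

# The result's unknown dim follows from the root; no tie_shape on %2 is needed.
result_shape = dot_result_shape(lhs_shape, (256, 256))
print(result_shape)
```

The point of the sketch is that the result's `?` is never supplied directly: it is derived from the input root and the op's shape rule.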

Note that once the dispatch regions have been outlined and compiled, the tie_shapes in the original, outer function are the only remaining way to rediscover full shapes, since flow.dispatch calls have no general way to reify a shape calculation. Example:

    %18 = shapex.tie_shape %arg0, %arg1 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
    %19 = shapex.ranked_dim %1[0] : !shapex.ranked_shape<[?,256]> -> index
    %20 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?,784]> -> index
    %21 = flow.dispatch @predict_ex_dispatch_0::@predict_ex_dispatch_0[%2 : index](%cst_4, %19, %18, %20) : (tensor<784x256xf32>, index, tensor<?x784xf32>, index) -> tensor<?x256xf32>
    %22 = shapex.tie_shape %21, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>

Backend compilation

For each dispatch region, interfaces are materialized, with the tie_shapes at the roots. Example:

  hal.executable @predict_ex_dispatch_0 attributes {sym_visibility = "private"} {
    hal.interface @legacy_io attributes {push_constants = 2 : i32} {
      hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
      hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
      hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
    }
    hal.executable.entry_point @predict_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<784x256xf32>, index, tensor<?x784xf32>, index) -> tensor<?x256xf32>}
    hal.executable.target "vulkan*" {
      module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
        func @predict_ex_dispatch_0() {
          %c0 = constant 0 : index
          %0 = hal.interface.load.constant offset = 0 : index
          %1 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<784x256xf32>
          %2 = hal.interface.load.constant offset = 1 : index
          %3 = hal.interface.load.tensor @legacy_io::@arg2, offset = %c0 : tensor<?x784xf32>
          %4 = call @predict_ex_dispatch_0_impl(%1, %0, %3, %2) : (tensor<784x256xf32>, index, tensor<?x784xf32>, index) -> tensor<?x256xf32>
          hal.interface.store.tensor %4, @legacy_io::@ret0, offset = %c0 : tensor<?x256xf32>
          return
        }
        func @predict_ex_dispatch_0_impl(%arg0: tensor<784x256xf32>, %arg1: index, %arg2: tensor<?x784xf32>, %arg3: index) -> tensor<?x256xf32> attributes {sym_visibility = "private"} {
          %0 = shapex.make_ranked_shape %arg3 : (index) -> !shapex.ranked_shape<[?,784]>
          %1 = shapex.tie_shape %arg2, %0 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
          %2 = "xla_hlo.dot"(%1, %arg0) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
          return %2 : tensor<?x256xf32>
        }
        hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
          hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
          hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
          hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
        }
      }
    }
  }
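
The entry point / impl split above can be modeled in a few lines of Python. This is a rough sketch under stated assumptions: the `push_constants` list and `bindings` dict are hypothetical stand-ins for the HAL interface, tensors are nested lists, and the dims are shrunk (3x2 weights instead of 784x256) so the result is easy to check by hand.

```python
def dispatch_impl(weights, result_dim0, lhs, lhs_dim0):
    """Model of @predict_ex_dispatch_0_impl: a dot whose dynamic dim arrives
    as a separate index. result_dim0 is accepted but unused, as in the IR."""
    k = len(weights)
    n = len(weights[0])
    # Stand-in for the root tie_shape: trust the index, not the buffer.
    rows = lhs[:lhs_dim0]
    return [
        [sum(row[p] * weights[p][j] for p in range(k)) for j in range(n)]
        for row in rows
    ]


def dispatch_entry(push_constants, bindings):
    """Model of @predict_ex_dispatch_0: load dims from push constants and
    tensors from bindings, call the impl, and store the result."""
    result_dim0 = push_constants[0]   # hal.interface.load.constant offset = 0
    weights = bindings["arg0"]
    lhs_dim0 = push_constants[1]      # hal.interface.load.constant offset = 1
    lhs = bindings["arg2"]
    bindings["ret0"] = dispatch_impl(weights, result_dim0, lhs, lhs_dim0)


# Shrunk example: batch of 2, 3 input features, 2 output features.
bindings = {
    "arg0": [[1, 0], [0, 1], [1, 1]],
    "arg2": [[1, 2, 3], [4, 5, 6]],
}
dispatch_entry([2, 2], bindings)
print(bindings["ret0"])
```

Both push constants carry the same batch size here; only the one tied to the input is needed to compute the result.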

For a specific executable, shape calculations are reified:

// *** IR Dump After mlir::iree_compiler::Shape::`anonymous-namespace'::HoistShapeCalculations ***
func @predict_ex_dispatch_0() {
  %0 = hal.interface.load.constant offset = 1 : index
  %1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,256]>
  %2 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,784]>
  %c0 = constant 0 : index
  %3 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<784x256xf32>
  %4 = hal.interface.load.tensor @legacy_io::@arg2, offset = %c0 : tensor<?x784xf32>
  %5 = shapex.tie_shape %4, %2 : tensor<?x784xf32>, !shapex.ranked_shape<[?,784]>
  %6 = "xla_hlo.dot"(%5, %3) : (tensor<?x784xf32>, tensor<784x256xf32>) -> tensor<?x256xf32>
  %7 = shapex.tie_shape %6, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
  hal.interface.store.tensor %7, @legacy_io::@ret0, offset = %c0 : tensor<?x256xf32>
  return
}

And after lowering to buffers:

// *** IR Dump After mlir::iree_compiler::`anonymous-namespace'::ConvertHLOToLinalgOnBuffersPass ***
func @predict_ex_dispatch_0() {
  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x256xf32>
  %c0 = constant 0 : index
  %1 = hal.interface.load.constant offset = 1 : index
  %2 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,256]>
  %3 = shapex.tie_shape %0, %2 : memref<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %4 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,784]>
  %5 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<784x256xf32>
  %6 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x784xf32>
  %7 = shapex.tie_shape %6, %4 : memref<?x784xf32>, !shapex.ranked_shape<[?,784]>
  %cst = constant 0.000000e+00 : f32
  linalg.fill(%3, %cst) : memref<?x256xf32>, f32
  linalg.matmul(%7, %5, %3) : memref<?x784xf32>, memref<784x256xf32>, memref<?x256xf32>
  return
}

After tile and fuse. Note that this is where std.dim ops get introduced; each one is expected to be resolvable to a root of some kind.

// *** IR Dump After mlir::iree_compiler::`anonymous-namespace'::LinalgTileAndFusePass ***
func @predict_ex_dispatch_0() attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
  %cst = constant 0.000000e+00 : f32
  %c0 = constant 0 : index
  %c4 = constant 4 : index
  %c784 = constant 784 : index
  %c8 = constant 8 : index
  %c256 = constant 256 : index
  %c1 = constant 1 : index
  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x256xf32>
  %1 = hal.interface.load.constant offset = 1 : index
  %2 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,256]>
  %3 = shapex.tie_shape %0, %2 : memref<?x256xf32>, !shapex.ranked_shape<[?,256]>
  %4 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,784]>
  %5 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<784x256xf32>
  %6 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x784xf32>
  %7 = shapex.tie_shape %6, %4 : memref<?x784xf32>, !shapex.ranked_shape<[?,784]>
  %8 = dim %3, 0 : memref<?x256xf32>
  scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%8, %c256) step (%c8, %c8) {
    %10 = dim %3, 0 : memref<?x256xf32>
    %11 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %10, %arg0)
    %12 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %c256, %arg1)
    %13 = subview %3[%arg0, %arg1] [%11, %12] [%c1, %c1]  : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
    linalg.fill(%13, %cst) {__internal_linalg_transform__ = "workitem"} : memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>, f32
    scf.yield
  }
  %9 = dim %7, 0 : memref<?x784xf32>
  scf.parallel (%arg0, %arg1) = (%c0, %c0) to (%9, %c256) step (%c8, %c8) {
    scf.for %arg2 = %c0 to %c784 step %c4 {
      %10 = dim %7, 0 : memref<?x784xf32>
      %11 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %10, %arg0)
      %12 = affine.min affine_map<(d0, d1, d2) -> (4, d1 - d2)>(%c4, %c784, %arg2)
      %13 = subview %7[%arg0, %arg2] [%11, %12] [%c1, %c1]  : memref<?x784xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
      %14 = affine.min affine_map<(d0, d1, d2) -> (4, d1 - d2)>(%c4, %c784, %arg2)
      %15 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %c256, %arg1)
      %16 = subview %5[%arg2, %arg1] [%14, %15] [%c1, %c1]  : memref<784x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
      %17 = dim %3, 0 : memref<?x256xf32>
      %18 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %17, %arg0)
      %19 = affine.min affine_map<(d0, d1, d2) -> (8, d1 - d2)>(%c8, %c256, %arg1)
      %20 = subview %3[%arg0, %arg1] [%18, %19] [%c1, %c1]  : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
      linalg.matmul(%13, %16, %20) {__internal_linalg_transform__ = "workitem"} : memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>, memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
    }
    scf.yield
  }
  return
}
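
To make "resolvable to a root" concrete, here is a minimal Python sketch of how a `dim` on a tied value could be traced back through `tie_shape` and `make_ranked_shape`. It is an illustration only: the `DEFS` table is a hand-written stand-in for the values in the dump above, and `resolve_dim` is a hypothetical helper, not the actual IREE resolution pass.

```python
# value name -> op record; mirrors %1..%7 in the tile-and-fuse dump.
DEFS = {
    "%1": ("load_constant", 1),
    "%2": ("make_ranked_shape", ("%1",), (None, 256)),
    "%3": ("tie_shape", "%0", "%2"),
    "%4": ("make_ranked_shape", ("%1",), (None, 784)),
    "%7": ("tie_shape", "%6", "%4"),
}


def resolve_dim(defs, value, index):
    """Resolve `dim value, index` to a constant or to the value feeding the shape.

    Hypothetical helper: only handles the tie_shape -> make_ranked_shape chain.
    """
    op = defs.get(value)
    if op is None or op[0] != "tie_shape":
        raise ValueError(f"{value} has no tie_shape root")
    shape_op = defs[op[2]]
    if shape_op[0] != "make_ranked_shape":
        raise ValueError(f"shape of {value} is not a make_ranked_shape")
    _, dynamic_operands, pattern = shape_op
    if pattern[index] is not None:
        # Static dim: known from the type alone.
        return ("constant", pattern[index])
    # Dynamic dim: count the '?' slots before it to pick the right operand.
    position = sum(1 for d in pattern[:index] if d is None)
    return ("value", dynamic_operands[position])


print(resolve_dim(DEFS, "%3", 0))  # dim %3, 0
print(resolve_dim(DEFS, "%7", 0))  # dim %7, 0
```

In this model both `dim %3, 0` and `dim %7, 0` resolve to `%1`, the push constant loaded at offset 1. A `dim` on an untied value such as `%0` has no root and raises.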

Prior to resolving the shape ops (the dim ops still refer to the tie_shape results):

module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
  func @predict_ex_dispatch_0() attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
    %cst = constant 0.000000e+00 : f32
    %c4 = constant 4 : index
    %c784 = constant 784 : index
    %c8 = constant 8 : index
    %c-1 = constant -1 : index
    %c256 = constant 256 : index
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x256xf32>
    %1 = hal.interface.load.constant offset = 1 : index
    %2 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,256]>
    %3 = shapex.tie_shape %0, %2 : memref<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %4 = shapex.make_ranked_shape %1 : (index) -> !shapex.ranked_shape<[?,784]>
    %5 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<784x256xf32>
    %6 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x784xf32>
    %7 = shapex.tie_shape %6, %4 : memref<?x784xf32>, !shapex.ranked_shape<[?,784]>
    %8 = dim %3, 0 : memref<?x256xf32>
    %9 = "gpu.block_id"() {dimension = "x"} : () -> index
    %10 = "gpu.grid_dim"() {dimension = "x"} : () -> index
    %11 = "gpu.block_id"() {dimension = "y"} : () -> index
    %12 = "gpu.grid_dim"() {dimension = "y"} : () -> index
    %13 = muli %11, %c8 : index
    %14 = muli %12, %c8 : index
    %15 = muli %9, %c8 : index
    %16 = muli %10, %c8 : index
    scf.for %arg0 = %13 to %8 step %14 {
      scf.for %arg1 = %15 to %c256 step %16 {
        %18 = muli %arg0, %c-1 : index
        %19 = addi %18, %8 : index
        %20 = cmpi "slt", %c8, %19 : index
        %21 = select %20, %c8, %19 : index
        %22 = muli %arg1, %c-1 : index
        %23 = addi %22, %c256 : index
        %24 = cmpi "slt", %c8, %23 : index
        %25 = select %24, %c8, %23 : index
        %26 = subview %3[%arg0, %arg1] [%21, %25] [1, 1]  : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
        %27 = "gpu.thread_id"() {dimension = "x"} : () -> index
        %28 = "gpu.block_dim"() {dimension = "x"} : () -> index
        %29 = "gpu.thread_id"() {dimension = "y"} : () -> index
        %30 = "gpu.block_dim"() {dimension = "y"} : () -> index
        scf.for %arg2 = %29 to %21 step %30 {
          scf.for %arg3 = %27 to %25 step %28 {
            store %cst, %26[%arg2, %arg3] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
          }
        }
      }
    }
    %17 = dim %7, 0 : memref<?x784xf32>
    scf.for %arg0 = %13 to %17 step %14 {
      scf.for %arg1 = %15 to %c256 step %16 {
        scf.for %arg2 = %c0 to %c784 step %c4 {
          %18 = muli %arg0, %c-1 : index
          %19 = addi %18, %17 : index
          %20 = cmpi "slt", %c8, %19 : index
          %21 = select %20, %c8, %19 : index
          %22 = muli %arg2, %c-1 : index
          %23 = addi %22, %c784 : index
          %24 = cmpi "slt", %c4, %23 : index
          %25 = select %24, %c4, %23 : index
          %26 = subview %7[%arg0, %arg2] [%21, %25] [1, 1]  : memref<?x784xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 784 + s0 + d1)>>
          %27 = muli %arg1, %c-1 : index
          %28 = addi %27, %c256 : index
          %29 = cmpi "slt", %c8, %28 : index
          %30 = select %29, %c8, %28 : index
          %31 = subview %5[%arg2, %arg1] [%25, %30] [1, 1]  : memref<784x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
          %32 = addi %18, %8 : index
          %33 = cmpi "slt", %c8, %32 : index
          %34 = select %33, %c8, %32 : index
          %35 = subview %3[%arg0, %arg1] [%34, %30] [1, 1]  : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
          %36 = "gpu.thread_id"() {dimension = "x"} : () -> index
          %37 = "gpu.block_dim"() {dimension = "x"} : () -> index
          %38 = "gpu.thread_id"() {dimension = "y"} : () -> index
          %39 = "gpu.block_dim"() {dimension = "y"} : () -> index
          scf.for %arg3 = %38 to %21 step %39 {
            scf.for %arg4 = %36 to %30 step %37 {
              scf.for %arg5 = %c0 to %25 step %c1 {
                %40 = load %31[%arg5, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
                %41 = load %26[%arg3, %arg5] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 784 + s0 + d1)>>
                %42 = mulf %41, %40 : f32
                %43 = load %35[%arg3, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
                %44 = addf %43, %42 : f32
                store %44, %35[%arg3, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
              }
            }
          }
        }
      }
    }
    return
  }
  hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
  }
}

After std.dim resolving:

// *** IR Dump After mlir::iree_compiler::`anonymous-namespace'::ResolveShapeOpsPass ***
func @predict_ex_dispatch_0() attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>}} {
  %cst = constant 0.000000e+00 : f32
  %c4 = constant 4 : index
  %c784 = constant 784 : index
  %c8 = constant 8 : index
  %c-1 = constant -1 : index
  %c256 = constant 256 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x256xf32>
  %1 = hal.interface.load.constant offset = 1 : index
  %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<784x256xf32>
  %3 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg2} : memref<?x784xf32>
  %4 = "gpu.block_id"() {dimension = "x"} : () -> index
  %5 = "gpu.grid_dim"() {dimension = "x"} : () -> index
  %6 = "gpu.block_id"() {dimension = "y"} : () -> index
  %7 = "gpu.grid_dim"() {dimension = "y"} : () -> index
  %8 = muli %6, %c8 : index
  %9 = muli %7, %c8 : index
  %10 = muli %4, %c8 : index
  %11 = muli %5, %c8 : index
  scf.for %arg0 = %8 to %1 step %9 {
    scf.for %arg1 = %10 to %c256 step %11 {
      %12 = muli %arg0, %c-1 : index
      %13 = addi %12, %1 : index
      %14 = cmpi "slt", %c8, %13 : index
      %15 = select %14, %c8, %13 : index
      %16 = muli %arg1, %c-1 : index
      %17 = addi %16, %c256 : index
      %18 = cmpi "slt", %c8, %17 : index
      %19 = select %18, %c8, %17 : index
      %20 = subview %0[%arg0, %arg1] [%15, %19] [1, 1]  : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
      %21 = "gpu.thread_id"() {dimension = "x"} : () -> index
      %22 = "gpu.block_dim"() {dimension = "x"} : () -> index
      %23 = "gpu.thread_id"() {dimension = "y"} : () -> index
      %24 = "gpu.block_dim"() {dimension = "y"} : () -> index
      scf.for %arg2 = %23 to %15 step %24 {
        scf.for %arg3 = %21 to %19 step %22 {
          store %cst, %20[%arg2, %arg3] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
        }
      }
    }
  }
  scf.for %arg0 = %8 to %1 step %9 {
    scf.for %arg1 = %10 to %c256 step %11 {
      scf.for %arg2 = %c0 to %c784 step %c4 {
        %12 = muli %arg0, %c-1 : index
        %13 = addi %12, %1 : index
        %14 = cmpi "slt", %c8, %13 : index
        %15 = select %14, %c8, %13 : index
        %16 = muli %arg2, %c-1 : index
        %17 = addi %16, %c784 : index
        %18 = cmpi "slt", %c4, %17 : index
        %19 = select %18, %c4, %17 : index
        %20 = subview %3[%arg0, %arg2] [%15, %19] [1, 1]  : memref<?x784xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 784 + s0 + d1)>>
        %21 = muli %arg1, %c-1 : index
        %22 = addi %21, %c256 : index
        %23 = cmpi "slt", %c8, %22 : index
        %24 = select %23, %c8, %22 : index
        %25 = subview %2[%arg2, %arg1] [%19, %24] [1, 1]  : memref<784x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
        %26 = addi %12, %1 : index
        %27 = cmpi "slt", %c8, %26 : index
        %28 = select %27, %c8, %26 : index
        %29 = subview %0[%arg0, %arg1] [%28, %24] [1, 1]  : memref<?x256xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
        %30 = "gpu.thread_id"() {dimension = "x"} : () -> index
        %31 = "gpu.block_dim"() {dimension = "x"} : () -> index
        %32 = "gpu.thread_id"() {dimension = "y"} : () -> index
        %33 = "gpu.block_dim"() {dimension = "y"} : () -> index
        scf.for %arg3 = %32 to %15 step %33 {
          scf.for %arg4 = %30 to %24 step %31 {
            scf.for %arg5 = %c0 to %19 step %c1 {
              %34 = load %25[%arg5, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
              %35 = load %20[%arg3, %arg5] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 784 + s0 + d1)>>
              %36 = mulf %35, %34 : f32
              %37 = load %29[%arg3, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
              %38 = addf %37, %36 : f32
              store %38, %29[%arg3, %arg4] : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 256 + s0 + d1)>>
            }
          }
        }
      }
    }
  }
  return
}

Issues

Invasive to fusion

Here is an example of a broadcast followed by an add, which with static shapes would be expected to fuse. However, the intervening tie_shape ops (for example %9, which sits between the two linalg.generic ops) prevent the fusion patterns from applying. Note that this IR comes from an incomplete attempt made before I hacked the above examples to isolate everything at the top level.

module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
  func @predict_ex_dispatch_1() {
    %c0 = constant 0 : index
    %0 = hal.interface.load.constant offset = 1 : index
    %1 = shapex.make_ranked_shape %0 : (index) -> !shapex.ranked_shape<[?,256]>
    %2 = hal.interface.load.constant offset = 0 : index
    %3 = shapex.make_ranked_shape %2 : (index) -> !shapex.ranked_shape<[?,256]>
    %4 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<256xf32>
    %5 = hal.interface.load.tensor @legacy_io::@arg2, offset = %c0 : tensor<?x256xf32>
    %6 = shapex.tie_shape %5, %1 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %7 = shapex.tie_shape %6, %3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %8 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} %4 {
    ^bb0(%arg0: f32):  // no predecessors
      linalg.yield %arg0 : f32
    }: tensor<256xf32> -> tensor<?x256xf32>
    %9 = shapex.tie_shape %8, %3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    %10 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} %7, %9 {
    ^bb0(%arg0: f32, %arg1: f32):  // no predecessors
      %12 = addf %arg0, %arg1 : f32
      linalg.yield %12 : f32
    }: tensor<?x256xf32>, tensor<?x256xf32> -> tensor<?x256xf32>
    %11 = shapex.tie_shape %10, %3 : tensor<?x256xf32>, !shapex.ranked_shape<[?,256]>
    hal.interface.store.tensor %11, @legacy_io::@ret0, offset = %c0 : tensor<?x256xf32>
    return
  }
  hal.interface @legacy_io attributes {push_constants = 2 : i32, sym_visibility = "private"} {
    hal.interface.binding @arg1, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read"
    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
  }
}
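To make the failure mode concrete, here is a minimal sketch in Python (not MLIR, and not the actual IREE or linalg fusion code) of why a producer/consumer match stops applying. The `Op` class and both helper functions are hypothetical stand-ins: one matches only a direct `linalg.generic` producer, the other first looks through `shapex.tie_shape`.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Op:
    """Toy stand-in for a single-result operation (hypothetical, not an MLIR API)."""
    name: str
    operands: List["Op"] = field(default_factory=list)


def producer_strict(operand: Op) -> Optional[Op]:
    """Match only when the operand is produced directly by a linalg.generic."""
    return operand if operand.name == "linalg.generic" else None


def producer_through_ties(operand: Op) -> Optional[Op]:
    """Same match, but skip tie_shape ops first (operand 0 is the tied value)."""
    while operand.name == "shapex.tie_shape":
        operand = operand.operands[0]
    return operand if operand.name == "linalg.generic" else None


# Mirrors the %8 -> %9 -> %10 chain in the dispatch above.
shape = Op("shapex.make_ranked_shape")
lhs = Op("hal.interface.load.tensor")
broadcast = Op("linalg.generic")                    # %8
tied = Op("shapex.tie_shape", [broadcast, shape])   # %9
add = Op("linalg.generic", [lhs, tied])             # %10

blocked = producer_strict(add.operands[1])        # None: tie_shape hides the producer
found = producer_through_ties(add.operands[1])    # the broadcast op
```

Every pattern that wants the second behavior has to opt in to it, which is what makes tie_shape invasive.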

Dominance relationships can be fragile

This is more of an "I've spent a lot of time on this" observation than a discrete example. Most lowerings need the shape of a result in order to generate ops that feed into the input, so it is easy to end up with IR that is not legally ordered, where a shape value is used before the op that defines it.
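As an illustration only, the kind of ordering bug described above can be modeled with a toy single-block checker in Python. The block representation, the value names, and `find_dominance_violations` are all made up for this sketch; real MLIR dominance also has to account for regions and multiple blocks.

```python
from typing import List, Tuple

# Each entry is (result_name, operand_names), listed in block order.
Block = List[Tuple[str, List[str]]]


def find_dominance_violations(block: Block) -> List[Tuple[str, str]]:
    """Return (user, operand) pairs where the operand is defined later in the block.

    Operands that are not results of any op in the block (for example block
    arguments) are treated as defined outside and always legal.
    """
    all_results = {result for result, _ in block}
    defined = set()
    violations = []
    for result, operands in block:
        for operand in operands:
            if operand in all_results and operand not in defined:
                violations.append((result, operand))
        defined.add(result)
    return violations


# A lowering materialized the result shape next to the result, but also used
# it to build an op that feeds the result's input.
bad_block: Block = [
    ("%dim", ["%arg"]),
    ("%init", ["%result_shape"]),
    ("%result", ["%arg", "%init"]),
    ("%result_shape", ["%dim"]),
]

# Same ops with the shape computation hoisted above its first use.
good_block: Block = [
    ("%dim", ["%arg"]),
    ("%result_shape", ["%dim"]),
    ("%init", ["%result_shape"]),
    ("%result", ["%arg", "%init"]),
]

bad = find_dominance_violations(bad_block)
good = find_dominance_violations(good_block)
```

Here the fix is a simple hoist because `%result_shape` depends only on `%dim`. The fragile cases are the ones where the shape computation itself depends on values that are not yet available at the point of use.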

Hard to transform

If a pass needs to be "shape aware", it must be taught how to walk around the tie_shape ops. In this example, I've called them "identity metadata" ops and taught the pass to ensure that they are properly hoisted at the boundaries of dispatch regions (with bugs, I might add: see the canonicalize-cse-canonicalize cycle above, which I am presently using as a workaround).

I feel like what I actually wish I were reaching for here is an analysis that holds all shape associations and that I could query and update. However, because this involves actually reifying the shapes, there is no place to "terminate" the SSA values representing shape components, which makes such an analysis hard to conceive of. I wish there were a less intrusive way to "terminate" these reified shapes for a period of time and then drop them, rather than needing to teach multiple layers of machinery about tie_shape or suffering the phase-ordering issues of only being able to reason about shapes once they are anchored on allocs (and presumed available in a runtime struct).
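For what it's worth, the query/update interface I have in mind could look roughly like the following Python sketch. It is purely hypothetical (`ShapeAssociations` and its methods do not exist in IREE), and it sidesteps the hard part: dynamic extents are recorded as SSA value names, so nothing here keeps those values alive or "terminates" them.

```python
from typing import Dict, Optional, Tuple, Union

# A static extent, or the name of an SSA value holding a dynamic extent.
Dim = Union[int, str]


class ShapeAssociations:
    """Toy side table mapping value names to per-dimension extents."""

    def __init__(self) -> None:
        self._shapes: Dict[str, Tuple[Dim, ...]] = {}

    def associate(self, value: str, dims: Tuple[Dim, ...]) -> None:
        """Record (or overwrite) the shape associated with a value."""
        self._shapes[value] = dims

    def dim(self, value: str, index: int) -> Optional[Dim]:
        """Return the extent of one dimension, or None if the value is unknown."""
        dims = self._shapes.get(value)
        return None if dims is None else dims[index]

    def replace_value(self, old: str, new: str) -> None:
        """Carry the association over when a pass rewrites `old` into `new`."""
        if old in self._shapes:
            self._shapes[new] = self._shapes.pop(old)


# Associations matching the first dispatch above: both buffers share the
# dynamic batch dimension loaded from the push constant.
shapes = ShapeAssociations()
shapes.associate("%ret0", ("%batch", 256))
shapes.associate("%arg2", ("%batch", 784))

# A pass that replaces %ret0 updates the table instead of re-tying ops.
shapes.replace_value("%ret0", "%ret0_buffer")
```

With a table like this, a dim query on the rewritten value resolves without any pass having to step over tie_shape ops in the IR.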
