@bjacob
Last active April 24, 2026 20:21

[mlir][vector] Gather1DToConditionalLoads emits ~8x more integer ops per lane for rank>1 memrefs after #184706

Summary

Commit 305dc4e5a9a6 (PR #184706, "Lower vector.gather with delinearization approach") makes Gather1DToConditionalLoads unconditionally compute load indices for rank > 1 memrefs as

flatIdx        = linearize(offsets, shape) + gatherIdx
loadIndices    = delinearize(flatIdx, shape)

This is correct in general (fixes a real bug for strided/non-contiguous memrefs) but it is a strict regression for the common case where only the innermost memref dim varies across lanes — which is the vast majority of vectorised vector.gathers produced from tensor.extract by IREE's LLVMCPU pipeline.
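To make the two formulas above concrete, here is a minimal scalar sketch of row-major linearization and delinearization in C++. The helper names (linearize, delinearize, gatherLoadIndices) are illustrative only and are not MLIR APIs; the sketch assumes static shapes, rank >= 1, and a non-negative, in-bounds flat index.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Index = std::vector<int64_t>;

// Row-major linearization: (((o0 * s1) + o1) * s2 + o2) ...
// Hypothetical helper, modelled on what affine.linearize_index computes.
int64_t linearize(const Index &offsets, const Index &shape) {
  assert(offsets.size() == shape.size());
  int64_t flat = 0;
  for (std::size_t d = 0; d < shape.size(); ++d)
    flat = flat * shape[d] + offsets[d];
  return flat;
}

// Inverse of linearize, valid for 0 <= flat < product(shape).
Index delinearize(int64_t flat, const Index &shape) {
  Index result(shape.size());
  for (std::size_t d = shape.size(); d-- > 0;) {
    result[d] = flat % shape[d];
    flat /= shape[d];
  }
  return result;
}

// The per-lane computation described above:
//   flatIdx     = linearize(offsets, shape) + gatherIdx
//   loadIndices = delinearize(flatIdx, shape)
Index gatherLoadIndices(const Index &offsets, const Index &shape,
                        int64_t gatherIdx) {
  return delinearize(linearize(offsets, shape) + gatherIdx, shape);
}
```

For shape (50, 40, 40), offsets (3, 7, 10) and gatherIdx 5, the flat index is 5095 and the load indices come out as (3, 7, 15), which is what a single add on the innermost offset would also have produced.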

Before the change, each lane emitted a single addi to offset the innermost dim. After the change, each lane additionally emits a full affine.delinearize_index that, after -lower-affine, expands to floordivsi + divsi + 2·remsi + 2·cmpi + 2·select + 2·addi, all of which operate on signed index integers and include two integer divisions on the critical path. On x86 with AVX2 (8-lane gather, v3 cost model), this multiplies the integer work per gather by ~8x.

Measured on an end-to-end IREE model (a customer workload that gathers from 2-D/3-D index tables inside a linalg.generic):

LLVM commit               iree-benchmark-module time
b7c4615e1378 (parent)     1.94 ms
305dc4e5a9a6 (this PR)    8.54 ms

The minimal MLIR reproducer below produces the same degraded IR from a single MLIR pass, with no IREE-specific context or runtime needed.


Reproducer

Input — a single vector.gather from a static rank-3 memref:

// gather_input.mlir
func.func @gather_memref_3d(%base: memref<50x40x40xi8>,
                            %idx: vector<8xindex>,
                            %mask: vector<8xi1>,
                            %pass: vector<8xi8>,
                            %i: index, %j: index, %k: index) -> vector<8xi8> {
  %0 = vector.gather %base[%i, %j, %k] [%idx], %mask, %pass
     : memref<50x40x40xi8>, vector<8xindex>, vector<8xi1>, vector<8xi8>
       into vector<8xi8>
  return %0 : vector<8xi8>
}

Apply the gather lowering (via any pipeline that calls populateVectorGatherLoweringPatterns — e.g. upstream -test-vector-gather-lowering, or IREE's iree-llvmcpu-virtual-vector-lowering):

iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmcpu-virtual-vector-lowering))' gather_input.mlir

Output BEFORE #184706 (lane 0 shown; all 8 lanes identical in shape)

%1 = vector.extract %arg1[0] : index from vector<8xindex>
%2 = arith.addi %arg6, %1 : index                 // <-- 1 add per lane
%3 = scf.if %0 -> (vector<8xi8>) {
  %32 = vector.load %arg0[%arg4, %arg5, %2]       // direct 3-D load
      : memref<50x40x40xi8>, vector<1xi8>
  ...
}

Output AFTER #184706 (lane 0 shown; repeated for all 8 lanes)

Before -lower-affine:

%0 = affine.linearize_index [%arg4, %arg5, %arg6] by (50, 40, 40) : index  // hoisted, once
...
%2 = vector.extract %arg1[0] : index from vector<8xindex>
%3 = arith.addi %0, %2 : index
%4:3 = affine.delinearize_index %3 into (50, 40, 40)   // <-- NEW, per-lane
%5 = scf.if %1 -> (vector<8xi8>) {
  %41 = vector.load %arg0[%4#0, %4#1, %4#2]
      : memref<50x40x40xi8>, vector<1xi8>
  ...
}

After -lower-affine (what actually reaches LLVM):

%6  = arith.addi %3, %5  : index                 // gather index + linear offset
%7  = arith.floordivsi %6, %c1600 : index        // \
%8  = arith.remsi      %6, %c1600 : index        //  \
%9  = arith.cmpi slt, %8, %c0 : index            //   \
%10 = arith.addi %8, %c1600 overflow<nsw> : index//    >  NEW per lane:
%11 = arith.select %9, %10, %8 : index           //    /  1 floordivsi
%12 = arith.divsi %11, %c40 : index              //   /   1 divsi
%13 = arith.remsi %6, %c40 : index               //  /    2 remsi
%14 = arith.cmpi slt, %13, %c0 : index           // /     2 cmpi
%15 = arith.addi %13, %c40 overflow<nsw> : index //       2 addi
%16 = arith.select %14, %15, %13 : index         //       2 select
%17 = scf.if %4 -> (vector<8xi8>) {
  %116 = vector.load %arg0[%7, %12, %16] ...     // lane load
  ...
}

Per-lane integer op count: 1 before → 11 after (the original addi plus 10 new ops, two of which are signed integer divisions, which cost tens of cycles on x86). Multiplied by 8 lanes, that's 8 vs 88 integer ops per gather. The IR size of the lowered function grows from 97 → 207 lines for this single gather.
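The lowered sequence can be mirrored in scalar C++ to see what each lane now computes. This is only a sketch for the 50x40x40 shape of the reproducer, assuming a 64-bit index type; the function names are hypothetical, and the comments map each step to the SSA values in the excerpt above.

```cpp
#include <cassert>
#include <cstdint>

struct Index3 {
  int64_t d0, d1, d2;
};

// Models arith.floordivsi: signed division rounding toward negative infinity.
int64_t floorDivSI(int64_t a, int64_t b) {
  int64_t q = a / b;  // C++ '/' truncates toward zero, like arith.divsi
  if ((a % b != 0) && ((a < 0) != (b < 0)))
    --q;
  return q;
}

// Models arith.remsi followed by the cmpi/addi/select fixup that makes
// the remainder non-negative.
int64_t nonNegativeRem(int64_t a, int64_t b) {
  int64_t r = a % b;         // arith.remsi: sign follows the dividend
  return r < 0 ? r + b : r;  // cmpi slt + addi + select
}

// Per-lane index computation after -lower-affine for memref<50x40x40xi8>.
Index3 delinearizeLowered(int64_t linearOffset, int64_t gatherIdx) {
  const int64_t c1600 = 40 * 40;
  const int64_t c40 = 40;
  int64_t flat = linearOffset + gatherIdx;        // %6
  int64_t d0 = floorDivSI(flat, c1600);           // %7
  int64_t rem1600 = nonNegativeRem(flat, c1600);  // %8 .. %11
  int64_t d1 = rem1600 / c40;                     // %12 (divsi)
  int64_t d2 = nonNegativeRem(flat, c40);         // %13 .. %16
  return {d0, d1, d2};
}
```

The sign fixups exist because the index type is signed: for a flat index of -1 the sequence yields (-1, 39, 39) rather than a truncated result. In the old path none of this arithmetic was emitted, since only the innermost offset changed.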

Why linearize → add → delinearize does not cancel

For the very common case where the base offsets (%arg4, %arg5, %arg6) are arbitrary SSA values (e.g. IVs of a parallel linalg.generic), the compiler cannot in general prove that the result of delinearize(linearize(o) + k) equals (o_0, o_1, o_2 + k) without bounds information on the offsets. Upstream affine folds and arithmetic canonicalisers do not eliminate this pattern in practice. The produced IR therefore reaches the backend with the full delinearisation arithmetic still present, and in IREE's pipeline this ends up in the innermost vectorised loop body.
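A small numeric check illustrates why the rewrite is not valid without bounds. The sketch below hard-codes the inner dims of the reproducer's memref<50x40x40xi8>; the helper names (linearize, delinearize, innermostAdd, cancels) are made up for illustration, and delinearize assumes a non-negative flat index.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Index3 = std::array<int64_t, 3>;

// Inner dims of memref<50x40x40xi8>; the outermost size does not affect
// row-major linearization.
constexpr int64_t kDim1 = 40;
constexpr int64_t kDim2 = 40;

int64_t linearize(const Index3 &o) {
  return (o[0] * kDim1 + o[1]) * kDim2 + o[2];
}

// Valid for non-negative flat indices only.
Index3 delinearize(int64_t flat) {
  return {flat / (kDim1 * kDim2), (flat / kDim2) % kDim1, flat % kDim2};
}

// The index tuple of the single-addi path: only the innermost dim moves.
Index3 innermostAdd(const Index3 &o, int64_t k) {
  return {o[0], o[1], o[2] + k};
}

// True when delinearize(linearize(o) + k) is the same tuple as
// (o_0, o_1, o_2 + k), i.e. when no carry into an outer dim occurs.
bool cancels(const Index3 &o, int64_t k) {
  return delinearize(linearize(o) + k) == innermostAdd(o, k);
}
```

With o = (1, 2, 10) and k = 5 both forms give (1, 2, 15). With o = (1, 2, 39) and k = 1 the delinearized form gives (1, 3, 0) while the single add gives (1, 2, 40). Both tuples linearize to 1720, but they are different SSA index tuples, so a folder can only replace one with the other if it knows 0 <= o_2 + k < 40.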

Root-cause in the pattern

From mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (Gather1DToConditionalLoads), introduced by #184706:

if (memType.getRank() > 1)
  useDelinearization = true;
...
if (useDelinearization) {
  Value flatIdx =
      rewriter.createOrFold<arith::AddIOp>(loc, linearizedOffsets, index);
  auto delinOp = affine::AffineDelinearizeIndexOp::create(
      rewriter, loc, flatIdx, baseShape, /*hasOuterBound=*/true);
  for (int64_t d = 0, rank = loadOffsets.size(); d < rank; ++d)
    loadOffsets[d] = delinOp.getResult(d);
} else {
  loadOffsets.back() =
      rewriter.createOrFold<arith::AddIOp>(loc, lastLoadOffset, index);
}

The useDelinearization branch is taken unconditionally for all rank > 1 memrefs, including the case that the old branch already handled correctly: a contiguous memref with unit innermost stride, where only the innermost dim is offset by the gather index. The correctness fix that motivated the PR concerns strided / non-unit-innermost-stride memrefs and cases where gatherIdx could carry across dims; neither concern applies to the contiguous rank-N, unit-innermost-stride case.

Suggested fix (sketch)

Guard useDelinearization on something stronger than memType.getRank() > 1. Conservative options, in increasing order of complexity:

  1. Only take the delinearization path when the memref is not a plain identity-layout memref with unit innermost stride.
  2. Additionally, only take it when the analysis cannot prove that offsets.back() + max(gatherIdx) stays below the innermost dim's size, i.e. when carry across dims is actually possible.
  3. Keep the old single-addi path by default, and surface the delinearizing path behind a pattern option, used by clients that actually need it.

For the contiguous / unit-innermost-stride case (which is what tensor-backed IR vectorised from tensor.extract produces in IREE), option 1 is sufficient to close the regression and preserves the correctness fix for the strided case.
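As a rough illustration of option 1, the guard could look like the following standalone model. MemRefModel is a hypothetical stand-in for the shape and stride information the pattern would query; it is not the MLIR MemRefType API, it only handles static shapes, and the actual upstream check may need to look different.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the parts of a memref type the guard needs.
struct MemRefModel {
  std::vector<int64_t> shape;    // static sizes, outermost first
  std::vector<int64_t> strides;  // element strides, outermost first
};

// True when the strides are exactly the row-major strides implied by the
// shape, i.e. an identity-layout memref with unit innermost stride.
bool isContiguousRowMajor(const MemRefModel &m) {
  if (m.shape.empty() || m.shape.size() != m.strides.size())
    return false;
  int64_t expected = 1;
  for (std::size_t d = m.shape.size(); d-- > 0;) {
    if (m.strides[d] != expected)
      return false;
    expected *= m.shape[d];
  }
  return true;
}

// Option 1: keep the single-addi path for contiguous memrefs and take the
// delinearization path only for rank > 1 memrefs with any other layout.
bool useDelinearization(const MemRefModel &m) {
  return m.shape.size() > 1 && !isContiguousRowMajor(m);
}
```

Under this model the reproducer's memref<50x40x40xi8> (strides 1600, 40, 1) keeps the old path, while a memref with strides such as (3200, 80, 2) still goes through delinearization.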

Files

  • gather_input.mlir — the reproducer input above.
  • gather_before.mlir — full lowered IR with #184706 reverted.
  • gather_after.mlir — full lowered IR with #184706 applied.
  • gather_after_lowered.mlir — same but with -lower-affine applied, showing the arithmetic that actually reaches the backend.

gather_after.mlir

module {
func.func private @_load_dialects() {
return
}
func.func @gather_memref_3d(%arg0: memref<50x40x40xi8>, %arg1: vector<8xindex>, %arg2: vector<8xi1>, %arg3: vector<8xi8>, %arg4: index, %arg5: index, %arg6: index) -> vector<8xi8> {
%0 = affine.linearize_index [%arg4, %arg5, %arg6] by (50, 40, 40) : index
%1 = vector.extract %arg2[0] : i1 from vector<8xi1>
%2 = vector.extract %arg1[0] : index from vector<8xindex>
%3 = arith.addi %0, %2 : index
%4:3 = affine.delinearize_index %3 into (50, 40, 40) : index, index, index
%5 = scf.if %1 -> (vector<8xi8>) {
%41 = vector.load %arg0[%4#0, %4#1, %4#2] : memref<50x40x40xi8>, vector<1xi8>
%42 = vector.extract %41[0] : i8 from vector<1xi8>
%43 = vector.insert %42, %arg3 [0] : i8 into vector<8xi8>
scf.yield %43 : vector<8xi8>
} else {
scf.yield %arg3 : vector<8xi8>
}
%6 = vector.extract %arg2[1] : i1 from vector<8xi1>
%7 = vector.extract %arg1[1] : index from vector<8xindex>
%8 = arith.addi %0, %7 : index
%9:3 = affine.delinearize_index %8 into (50, 40, 40) : index, index, index
%10 = scf.if %6 -> (vector<8xi8>) {
%41 = vector.load %arg0[%9#0, %9#1, %9#2] : memref<50x40x40xi8>, vector<1xi8>
%42 = vector.extract %41[0] : i8 from vector<1xi8>
%43 = vector.insert %42, %5 [1] : i8 into vector<8xi8>
scf.yield %43 : vector<8xi8>
} else {
scf.yield %5 : vector<8xi8>
}
%11 = vector.extract %arg2[2] : i1 from vector<8xi1>
%12 = vector.extract %arg1[2] : index from vector<8xindex>
%13 = arith.addi %0, %12 : index
%14:3 = affine.delinearize_index %13 into (50, 40, 40) : index, index, index
%15 = scf.if %11 -> (vector<8xi8>) {
%41 = vector.load %arg0[%14#0, %14#1, %14#2] : memref<50x40x40xi8>, vector<1xi8>
%42 = vector.extract %41[0] : i8 from vector<1xi8>
%43 = vector.insert %42, %10 [2] : i8 into vector<8xi8>
scf.yield %43 : vector<8xi8>
} else {
scf.yield %10 : vector<8xi8>
}
%16 = vector.extract %arg2[3] : i1 from vector<8xi1>
%17 = vector.extract %arg1[3] : index from vector<8xindex>
%18 = arith.addi %0, %17 : index
%19:3 = affine.delinearize_index %18 into (50, 40, 40) : index, index, index
%20 = scf.if %16 -> (vector<8xi8>) {
%41 = vector.load %arg0[%19#0, %19#1, %19#2] : memref<50x40x40xi8>, vector<1xi8>
%42 = vector.extract %41[0] : i8 from vector<1xi8>
%43 = vector.insert %42, %15 [3] : i8 into vector<8xi8>
scf.yield %43 : vector<8xi8>
} else {
scf.yield %15 : vector<8xi8>
}
%21 = vector.extract %arg2[4] : i1 from vector<8xi1>
%22 = vector.extract %arg1[4] : index from vector<8xindex>
%23 = arith.addi %0, %22 : index
%24:3 = affine.delinearize_index %23 into (50, 40, 40) : index, index, index
%25 = scf.if %21 -> (vector<8xi8>) {
%41 = vector.load %arg0[%24#0, %24#1, %24#2] : memref<50x40x40xi8>, vector<1xi8>
%42 = vector.extract %41[0] : i8 from vector<1xi8>
%43 = vector.insert %42, %20 [4] : i8 into vector<8xi8>
scf.yield %43 : vector<8xi8>
} else {
scf.yield %20 : vector<8xi8>
}
%26 = vector.extract %arg2[5] : i1 from vector<8xi1>
%27 = vector.extract %arg1[5] : index from vector<8xindex>
%28 = arith.addi %0, %27 : index
%29:3 = affine.delinearize_index %28 into (50, 40, 40) : index, index, index
%30 = scf.if %26 -> (vector<8xi8>) {
%41 = vector.load %arg0[%29#0, %29#1, %29#2] : memref<50x40x40xi8>, vector<1xi8>
%42 = vector.extract %41[0] : i8 from vector<1xi8>
%43 = vector.insert %42, %25 [5] : i8 into vector<8xi8>
scf.yield %43 : vector<8xi8>
} else {
scf.yield %25 : vector<8xi8>
}
%31 = vector.extract %arg2[6] : i1 from vector<8xi1>
%32 = vector.extract %arg1[6] : index from vector<8xindex>
%33 = arith.addi %0, %32 : index
%34:3 = affine.delinearize_index %33 into (50, 40, 40) : index, index, index
%35 = scf.if %31 -> (vector<8xi8>) {
%41 = vector.load %arg0[%34#0, %34#1, %34#2] : memref<50x40x40xi8>, vector<1xi8>
%42 = vector.extract %41[0] : i8 from vector<1xi8>
%43 = vector.insert %42, %30 [6] : i8 into vector<8xi8>
scf.yield %43 : vector<8xi8>
} else {
scf.yield %30 : vector<8xi8>
}
%36 = vector.extract %arg2[7] : i1 from vector<8xi1>
%37 = vector.extract %arg1[7] : index from vector<8xindex>
%38 = arith.addi %0, %37 : index
%39:3 = affine.delinearize_index %38 into (50, 40, 40) : index, index, index
%40 = scf.if %36 -> (vector<8xi8>) {
%41 = vector.load %arg0[%39#0, %39#1, %39#2] : memref<50x40x40xi8>, vector<1xi8>
%42 = vector.extract %41[0] : i8 from vector<1xi8>
%43 = vector.insert %42, %35 [7] : i8 into vector<8xi8>
scf.yield %43 : vector<8xi8>
} else {
scf.yield %35 : vector<8xi8>
}
return %40 : vector<8xi8>
}
}

gather_after_lowered.mlir

module {
func.func private @_load_dialects() {
return
}
func.func @gather_memref_3d(%arg0: memref<50x40x40xi8>, %arg1: vector<8xindex>, %arg2: vector<8xi1>, %arg3: vector<8xi8>, %arg4: index, %arg5: index, %arg6: index) -> vector<8xi8> {
%c40 = arith.constant 40 : index
%c1600 = arith.constant 1600 : index
%0 = arith.muli %arg4, %c1600 overflow<nsw> : index
%1 = arith.muli %arg5, %c40 overflow<nsw> : index
%2 = arith.addi %0, %1 overflow<nsw> : index
%3 = arith.addi %2, %arg6 overflow<nsw> : index
%4 = vector.extract %arg2[0] : i1 from vector<8xi1>
%5 = vector.extract %arg1[0] : index from vector<8xindex>
%6 = arith.addi %3, %5 : index
%c40_0 = arith.constant 40 : index
%c1600_1 = arith.constant 1600 : index
%c0 = arith.constant 0 : index
%7 = arith.floordivsi %6, %c1600_1 : index
%8 = arith.remsi %6, %c1600_1 : index
%9 = arith.cmpi slt, %8, %c0 : index
%10 = arith.addi %8, %c1600_1 overflow<nsw> : index
%11 = arith.select %9, %10, %8 : index
%12 = arith.divsi %11, %c40_0 : index
%13 = arith.remsi %6, %c40_0 : index
%14 = arith.cmpi slt, %13, %c0 : index
%15 = arith.addi %13, %c40_0 overflow<nsw> : index
%16 = arith.select %14, %15, %13 : index
%17 = scf.if %4 -> (vector<8xi8>) {
%116 = vector.load %arg0[%7, %12, %16] : memref<50x40x40xi8>, vector<1xi8>
%117 = vector.extract %116[0] : i8 from vector<1xi8>
%118 = vector.insert %117, %arg3 [0] : i8 into vector<8xi8>
scf.yield %118 : vector<8xi8>
} else {
scf.yield %arg3 : vector<8xi8>
}
%18 = vector.extract %arg2[1] : i1 from vector<8xi1>
%19 = vector.extract %arg1[1] : index from vector<8xindex>
%20 = arith.addi %3, %19 : index
%c40_2 = arith.constant 40 : index
%c1600_3 = arith.constant 1600 : index
%c0_4 = arith.constant 0 : index
%21 = arith.floordivsi %20, %c1600_3 : index
%22 = arith.remsi %20, %c1600_3 : index
%23 = arith.cmpi slt, %22, %c0_4 : index
%24 = arith.addi %22, %c1600_3 overflow<nsw> : index
%25 = arith.select %23, %24, %22 : index
%26 = arith.divsi %25, %c40_2 : index
%27 = arith.remsi %20, %c40_2 : index
%28 = arith.cmpi slt, %27, %c0_4 : index
%29 = arith.addi %27, %c40_2 overflow<nsw> : index
%30 = arith.select %28, %29, %27 : index
%31 = scf.if %18 -> (vector<8xi8>) {
%116 = vector.load %arg0[%21, %26, %30] : memref<50x40x40xi8>, vector<1xi8>
%117 = vector.extract %116[0] : i8 from vector<1xi8>
%118 = vector.insert %117, %17 [1] : i8 into vector<8xi8>
scf.yield %118 : vector<8xi8>
} else {
scf.yield %17 : vector<8xi8>
}
%32 = vector.extract %arg2[2] : i1 from vector<8xi1>
%33 = vector.extract %arg1[2] : index from vector<8xindex>
%34 = arith.addi %3, %33 : index
%c40_5 = arith.constant 40 : index
%c1600_6 = arith.constant 1600 : index
%c0_7 = arith.constant 0 : index
%35 = arith.floordivsi %34, %c1600_6 : index
%36 = arith.remsi %34, %c1600_6 : index
%37 = arith.cmpi slt, %36, %c0_7 : index
%38 = arith.addi %36, %c1600_6 overflow<nsw> : index
%39 = arith.select %37, %38, %36 : index
%40 = arith.divsi %39, %c40_5 : index
%41 = arith.remsi %34, %c40_5 : index
%42 = arith.cmpi slt, %41, %c0_7 : index
%43 = arith.addi %41, %c40_5 overflow<nsw> : index
%44 = arith.select %42, %43, %41 : index
%45 = scf.if %32 -> (vector<8xi8>) {
%116 = vector.load %arg0[%35, %40, %44] : memref<50x40x40xi8>, vector<1xi8>
%117 = vector.extract %116[0] : i8 from vector<1xi8>
%118 = vector.insert %117, %31 [2] : i8 into vector<8xi8>
scf.yield %118 : vector<8xi8>
} else {
scf.yield %31 : vector<8xi8>
}
%46 = vector.extract %arg2[3] : i1 from vector<8xi1>
%47 = vector.extract %arg1[3] : index from vector<8xindex>
%48 = arith.addi %3, %47 : index
%c40_8 = arith.constant 40 : index
%c1600_9 = arith.constant 1600 : index
%c0_10 = arith.constant 0 : index
%49 = arith.floordivsi %48, %c1600_9 : index
%50 = arith.remsi %48, %c1600_9 : index
%51 = arith.cmpi slt, %50, %c0_10 : index
%52 = arith.addi %50, %c1600_9 overflow<nsw> : index
%53 = arith.select %51, %52, %50 : index
%54 = arith.divsi %53, %c40_8 : index
%55 = arith.remsi %48, %c40_8 : index
%56 = arith.cmpi slt, %55, %c0_10 : index
%57 = arith.addi %55, %c40_8 overflow<nsw> : index
%58 = arith.select %56, %57, %55 : index
%59 = scf.if %46 -> (vector<8xi8>) {
%116 = vector.load %arg0[%49, %54, %58] : memref<50x40x40xi8>, vector<1xi8>
%117 = vector.extract %116[0] : i8 from vector<1xi8>
%118 = vector.insert %117, %45 [3] : i8 into vector<8xi8>
scf.yield %118 : vector<8xi8>
} else {
scf.yield %45 : vector<8xi8>
}
%60 = vector.extract %arg2[4] : i1 from vector<8xi1>
%61 = vector.extract %arg1[4] : index from vector<8xindex>
%62 = arith.addi %3, %61 : index
%c40_11 = arith.constant 40 : index
%c1600_12 = arith.constant 1600 : index
%c0_13 = arith.constant 0 : index
%63 = arith.floordivsi %62, %c1600_12 : index
%64 = arith.remsi %62, %c1600_12 : index
%65 = arith.cmpi slt, %64, %c0_13 : index
%66 = arith.addi %64, %c1600_12 overflow<nsw> : index
%67 = arith.select %65, %66, %64 : index
%68 = arith.divsi %67, %c40_11 : index
%69 = arith.remsi %62, %c40_11 : index
%70 = arith.cmpi slt, %69, %c0_13 : index
%71 = arith.addi %69, %c40_11 overflow<nsw> : index
%72 = arith.select %70, %71, %69 : index
%73 = scf.if %60 -> (vector<8xi8>) {
%116 = vector.load %arg0[%63, %68, %72] : memref<50x40x40xi8>, vector<1xi8>
%117 = vector.extract %116[0] : i8 from vector<1xi8>
%118 = vector.insert %117, %59 [4] : i8 into vector<8xi8>
scf.yield %118 : vector<8xi8>
} else {
scf.yield %59 : vector<8xi8>
}
%74 = vector.extract %arg2[5] : i1 from vector<8xi1>
%75 = vector.extract %arg1[5] : index from vector<8xindex>
%76 = arith.addi %3, %75 : index
%c40_14 = arith.constant 40 : index
%c1600_15 = arith.constant 1600 : index
%c0_16 = arith.constant 0 : index
%77 = arith.floordivsi %76, %c1600_15 : index
%78 = arith.remsi %76, %c1600_15 : index
%79 = arith.cmpi slt, %78, %c0_16 : index
%80 = arith.addi %78, %c1600_15 overflow<nsw> : index
%81 = arith.select %79, %80, %78 : index
%82 = arith.divsi %81, %c40_14 : index
%83 = arith.remsi %76, %c40_14 : index
%84 = arith.cmpi slt, %83, %c0_16 : index
%85 = arith.addi %83, %c40_14 overflow<nsw> : index
%86 = arith.select %84, %85, %83 : index
%87 = scf.if %74 -> (vector<8xi8>) {
%116 = vector.load %arg0[%77, %82, %86] : memref<50x40x40xi8>, vector<1xi8>
%117 = vector.extract %116[0] : i8 from vector<1xi8>
%118 = vector.insert %117, %73 [5] : i8 into vector<8xi8>
scf.yield %118 : vector<8xi8>
} else {
scf.yield %73 : vector<8xi8>
}
%88 = vector.extract %arg2[6] : i1 from vector<8xi1>
%89 = vector.extract %arg1[6] : index from vector<8xindex>
%90 = arith.addi %3, %89 : index
%c40_17 = arith.constant 40 : index
%c1600_18 = arith.constant 1600 : index
%c0_19 = arith.constant 0 : index
%91 = arith.floordivsi %90, %c1600_18 : index
%92 = arith.remsi %90, %c1600_18 : index
%93 = arith.cmpi slt, %92, %c0_19 : index
%94 = arith.addi %92, %c1600_18 overflow<nsw> : index
%95 = arith.select %93, %94, %92 : index
%96 = arith.divsi %95, %c40_17 : index
%97 = arith.remsi %90, %c40_17 : index
%98 = arith.cmpi slt, %97, %c0_19 : index
%99 = arith.addi %97, %c40_17 overflow<nsw> : index
%100 = arith.select %98, %99, %97 : index
%101 = scf.if %88 -> (vector<8xi8>) {
%116 = vector.load %arg0[%91, %96, %100] : memref<50x40x40xi8>, vector<1xi8>
%117 = vector.extract %116[0] : i8 from vector<1xi8>
%118 = vector.insert %117, %87 [6] : i8 into vector<8xi8>
scf.yield %118 : vector<8xi8>
} else {
scf.yield %87 : vector<8xi8>
}
%102 = vector.extract %arg2[7] : i1 from vector<8xi1>
%103 = vector.extract %arg1[7] : index from vector<8xindex>
%104 = arith.addi %3, %103 : index
%c40_20 = arith.constant 40 : index
%c1600_21 = arith.constant 1600 : index
%c0_22 = arith.constant 0 : index
%105 = arith.floordivsi %104, %c1600_21 : index
%106 = arith.remsi %104, %c1600_21 : index
%107 = arith.cmpi slt, %106, %c0_22 : index
%108 = arith.addi %106, %c1600_21 overflow<nsw> : index
%109 = arith.select %107, %108, %106 : index
%110 = arith.divsi %109, %c40_20 : index
%111 = arith.remsi %104, %c40_20 : index
%112 = arith.cmpi slt, %111, %c0_22 : index
%113 = arith.addi %111, %c40_20 overflow<nsw> : index
%114 = arith.select %112, %113, %111 : index
%115 = scf.if %102 -> (vector<8xi8>) {
%116 = vector.load %arg0[%105, %110, %114] : memref<50x40x40xi8>, vector<1xi8>
%117 = vector.extract %116[0] : i8 from vector<1xi8>
%118 = vector.insert %117, %101 [7] : i8 into vector<8xi8>
scf.yield %118 : vector<8xi8>
} else {
scf.yield %101 : vector<8xi8>
}
return %115 : vector<8xi8>
}
}

gather_before.mlir

module {
func.func private @_load_dialects() {
return
}
func.func @gather_memref_3d(%arg0: memref<50x40x40xi8>, %arg1: vector<8xindex>, %arg2: vector<8xi1>, %arg3: vector<8xi8>, %arg4: index, %arg5: index, %arg6: index) -> vector<8xi8> {
%0 = vector.extract %arg2[0] : i1 from vector<8xi1>
%1 = vector.extract %arg1[0] : index from vector<8xindex>
%2 = arith.addi %arg6, %1 : index
%3 = scf.if %0 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %2] : memref<50x40x40xi8>, vector<1xi8>
%33 = vector.extract %32[0] : i8 from vector<1xi8>
%34 = vector.insert %33, %arg3 [0] : i8 into vector<8xi8>
scf.yield %34 : vector<8xi8>
} else {
scf.yield %arg3 : vector<8xi8>
}
%4 = vector.extract %arg2[1] : i1 from vector<8xi1>
%5 = vector.extract %arg1[1] : index from vector<8xindex>
%6 = arith.addi %arg6, %5 : index
%7 = scf.if %4 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %6] : memref<50x40x40xi8>, vector<1xi8>
%33 = vector.extract %32[0] : i8 from vector<1xi8>
%34 = vector.insert %33, %3 [1] : i8 into vector<8xi8>
scf.yield %34 : vector<8xi8>
} else {
scf.yield %3 : vector<8xi8>
}
%8 = vector.extract %arg2[2] : i1 from vector<8xi1>
%9 = vector.extract %arg1[2] : index from vector<8xindex>
%10 = arith.addi %arg6, %9 : index
%11 = scf.if %8 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %10] : memref<50x40x40xi8>, vector<1xi8>
%33 = vector.extract %32[0] : i8 from vector<1xi8>
%34 = vector.insert %33, %7 [2] : i8 into vector<8xi8>
scf.yield %34 : vector<8xi8>
} else {
scf.yield %7 : vector<8xi8>
}
%12 = vector.extract %arg2[3] : i1 from vector<8xi1>
%13 = vector.extract %arg1[3] : index from vector<8xindex>
%14 = arith.addi %arg6, %13 : index
%15 = scf.if %12 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %14] : memref<50x40x40xi8>, vector<1xi8>
%33 = vector.extract %32[0] : i8 from vector<1xi8>
%34 = vector.insert %33, %11 [3] : i8 into vector<8xi8>
scf.yield %34 : vector<8xi8>
} else {
scf.yield %11 : vector<8xi8>
}
%16 = vector.extract %arg2[4] : i1 from vector<8xi1>
%17 = vector.extract %arg1[4] : index from vector<8xindex>
%18 = arith.addi %arg6, %17 : index
%19 = scf.if %16 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %18] : memref<50x40x40xi8>, vector<1xi8>
%33 = vector.extract %32[0] : i8 from vector<1xi8>
%34 = vector.insert %33, %15 [4] : i8 into vector<8xi8>
scf.yield %34 : vector<8xi8>
} else {
scf.yield %15 : vector<8xi8>
}
%20 = vector.extract %arg2[5] : i1 from vector<8xi1>
%21 = vector.extract %arg1[5] : index from vector<8xindex>
%22 = arith.addi %arg6, %21 : index
%23 = scf.if %20 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %22] : memref<50x40x40xi8>, vector<1xi8>
%33 = vector.extract %32[0] : i8 from vector<1xi8>
%34 = vector.insert %33, %19 [5] : i8 into vector<8xi8>
scf.yield %34 : vector<8xi8>
} else {
scf.yield %19 : vector<8xi8>
}
%24 = vector.extract %arg2[6] : i1 from vector<8xi1>
%25 = vector.extract %arg1[6] : index from vector<8xindex>
%26 = arith.addi %arg6, %25 : index
%27 = scf.if %24 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %26] : memref<50x40x40xi8>, vector<1xi8>
%33 = vector.extract %32[0] : i8 from vector<1xi8>
%34 = vector.insert %33, %23 [6] : i8 into vector<8xi8>
scf.yield %34 : vector<8xi8>
} else {
scf.yield %23 : vector<8xi8>
}
%28 = vector.extract %arg2[7] : i1 from vector<8xi1>
%29 = vector.extract %arg1[7] : index from vector<8xindex>
%30 = arith.addi %arg6, %29 : index
%31 = scf.if %28 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %30] : memref<50x40x40xi8>, vector<1xi8>
%33 = vector.extract %32[0] : i8 from vector<1xi8>
%34 = vector.insert %33, %27 [7] : i8 into vector<8xi8>
scf.yield %34 : vector<8xi8>
} else {
scf.yield %27 : vector<8xi8>
}
return %31 : vector<8xi8>
}
}

gather_input.mlir

// Minimal input reproducing the pattern modified by LLVM PR #184706
// ("[mlir][vector] Lower vector.gather with delinearization approach").
//
// Reference: mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp,
// pattern `Gather1DToConditionalLoads`.
// Force-load scf/affine/memref dialects so the gather-lowering pass
// can emit its generated ops.
func.func private @_load_dialects() {
%c = arith.constant false
scf.if %c { }
%a = affine.apply affine_map<() -> (0)> ()
return
}
func.func @gather_memref_3d(%base: memref<50x40x40xi8>,
%idx: vector<8xindex>,
%mask: vector<8xi1>,
%pass: vector<8xi8>,
%i: index, %j: index, %k: index) -> vector<8xi8> {
%0 = vector.gather %base[%i, %j, %k] [%idx], %mask, %pass
: memref<50x40x40xi8>, vector<8xindex>, vector<8xi1>, vector<8xi8>
into vector<8xi8>
return %0 : vector<8xi8>
}