[mlir][vector] Gather1DToConditionalLoads emits ~8x more integer ops per lane for rank>1 memrefs after #184706
Commit 305dc4e5a9a6 (PR #184706, "Lower vector.gather with
delinearization approach") makes
Gather1DToConditionalLoads unconditionally compute load indices for
rank > 1 memrefs as
flatIdx = linearize(offsets, shape) + gatherIdx
loadIndices = delinearize(flatIdx, shape)
This is correct in general (fixes a real bug for strided/non-contiguous
memrefs) but it is a strict regression for the common case where only
the innermost memref dim varies across lanes — which is the vast majority
of vectorised vector.gathers produced from tensor.extract by IREE's
LLVMCPU pipeline.
Before the change, each lane emitted a single addi to offset the
innermost dim. After the change, each lane additionally emits a full
affine.delinearize_index that, after -lower-affine, expands to
floordivsi + divsi + 2·remsi + 2·cmpi + 2·select + 2·addi, all of
which operate on signed index integers and include two integer
divisions on the critical path. On x86 with AVX2 (8-lane gather,
v3 cost model), this multiplies the integer work per gather by ~8x.
Measured on an end-to-end IREE model (a customer workload that gathers
from 2-D/3-D index tables inside a linalg.generic):
| LLVM commit | iree-benchmark-module time |
|---|---|
| b7c4615e1378 (parent) | 1.94 ms |
| 305dc4e5a9a6 (this PR) | 8.54 ms |
The minimal MLIR reproducer below yields the same degraded IR from a single MLIR pass, with no IREE-specific context or runtime needed.
Input — a single vector.gather from a static rank-3 memref:
// gather_input.mlir
func.func @gather_memref_3d(%base: memref<50x40x40xi8>,
%idx: vector<8xindex>,
%mask: vector<8xi1>,
%pass: vector<8xi8>,
%i: index, %j: index, %k: index) -> vector<8xi8> {
%0 = vector.gather %base[%i, %j, %k] [%idx], %mask, %pass
: memref<50x40x40xi8>, vector<8xindex>, vector<8xi1>, vector<8xi8>
into vector<8xi8>
return %0 : vector<8xi8>
}

Apply the gather lowering (via any pipeline that calls
populateVectorGatherLoweringPatterns — e.g. upstream
-test-vector-gather-lowering, or IREE's
iree-llvmcpu-virtual-vector-lowering):
iree-opt --pass-pipeline='builtin.module(func.func(iree-llvmcpu-virtual-vector-lowering))' gather_input.mlir

With #184706 reverted (parent commit), each lane lowers to:
%1 = vector.extract %arg1[0] : index from vector<8xindex>
%2 = arith.addi %arg6, %1 : index // <-- 1 add per lane
%3 = scf.if %0 -> (vector<8xi8>) {
%32 = vector.load %arg0[%arg4, %arg5, %2] // direct 3-D load
: memref<50x40x40xi8>, vector<1xi8>
...
}

With #184706 applied, before -lower-affine:
%0 = affine.linearize_index [%arg4, %arg5, %arg6] by (50, 40, 40) : index // hoisted, once
...
%2 = vector.extract %arg1[0] : index from vector<8xindex>
%3 = arith.addi %0, %2 : index
%4:3 = affine.delinearize_index %3 into (50, 40, 40) // <-- NEW, per-lane
%5 = scf.if %1 -> (vector<8xi8>) {
%41 = vector.load %arg0[%4#0, %4#1, %4#2]
: memref<50x40x40xi8>, vector<1xi8>
...
}

After -lower-affine (what actually reaches LLVM):
%6 = arith.addi %3, %5 : index // gather index + linear offset
%7 = arith.floordivsi %6, %c1600 : index // \
%8 = arith.remsi %6, %c1600 : index // \
%9 = arith.cmpi slt, %8, %c0 : index // \
%10 = arith.addi %8, %c1600 overflow<nsw> : index // > NEW per lane:
%11 = arith.select %9, %10, %8 : index // / 1 floordivsi
%12 = arith.divsi %11, %c40 : index // / 1 divsi
%13 = arith.remsi %6, %c40 : index // / 2 remsi
%14 = arith.cmpi slt, %13, %c0 : index // / 2 cmpi + 2 select
%15 = arith.addi %13, %c40 overflow<nsw> : index // / 2 addi
%16 = arith.select %14, %15, %13 : index // /
%17 = scf.if %4 -> (vector<8xi8>) {
%116 = vector.load %arg0[%7, %12, %16] ... // lane load
...
}

Per-lane integer op count: 1 before → 10 after (two of which are signed integer divisions, which cost tens of cycles on x86). Multiplied by 8 lanes, that's 8 vs 80 integer ops per gather. The IR size of the lowered function grows from 97 → 207 lines for this single gather.
For the very common case where the base offsets (%arg4, %arg5, %arg6)
are arbitrary SSA values (e.g. IVs of a parallel linalg.generic), the
compiler cannot in general prove that `delinearize(linearize(o) + k)`
equals `(o_0, o_1, o_2 + k)` without bounds information on the offsets.
Upstream `affine` folds and arithmetic canonicalisers do not eliminate
this pattern in practice. The produced IR therefore reaches the backend
with the full delinearisation arithmetic still present, and in IREE's
pipeline this ends up in the innermost vectorised loop body.
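
To make the bounds dependence concrete, here is a minimal standalone C++ sketch (not MLIR code) that mirrors the lowered arithmetic for the reproducer's 50x40x40 shape. The helper names (`linearize`, `delinearize`, `addInnermost`) and the sample offsets are hypothetical, chosen only for illustration. It shows that the two index computations agree when `o_2 + k` stays below the innermost dim size, and differ once the sum crosses it.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Index3 = std::array<int64_t, 3>;

// Static shape from the reproducer: memref<50x40x40xi8>.
constexpr Index3 kShape = {50, 40, 40};

// Row-major linearization: (i * 40 + j) * 40 + k.
int64_t linearize(const Index3 &o) {
  return (o[0] * kShape[1] + o[1]) * kShape[2] + o[2];
}

// Division rounding toward negative infinity (models arith.floordivsi).
int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  return (a % b != 0 && ((a < 0) != (b < 0))) ? q - 1 : q;
}

// Non-negative remainder (models remsi + cmpi slt + addi + select).
int64_t posRem(int64_t a, int64_t b) {
  int64_t r = a % b;
  return r < 0 ? r + b : r;
}

// Mirrors the per-lane sequence shown after -lower-affine.
Index3 delinearize(int64_t flat) {
  const int64_t inner = kShape[1] * kShape[2]; // 1600
  return {floorDiv(flat, inner), posRem(flat, inner) / kShape[2],
          posRem(flat, kShape[2])};
}

// Models the old lowering: only the innermost offset is adjusted.
Index3 addInnermost(const Index3 &o, int64_t k) {
  return {o[0], o[1], o[2] + k};
}

// Illustrative check with arbitrary sample offsets (3, 7, 10).
void checkReproducerOffsets() {
  const Index3 o = {3, 7, 10};
  // 10 + 5 < 40: both paths produce (3, 7, 15).
  assert(delinearize(linearize(o) + 5) == addInnermost(o, 5));
  // 10 + 35 >= 40: delinearization carries into the middle dim.
  assert(delinearize(linearize(o) + 35) != addInnermost(o, 35));
}
```

In the carry case the two index tuples still linearize to the same flat position (5125 for the sample values) under a row-major identity layout, which is why the equality only needs a bound on `o_2 + k` rather than any change of addressing.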
From mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
(Gather1DToConditionalLoads), introduced by #184706:
if (memType.getRank() > 1)
useDelinearization = true;
...
if (useDelinearization) {
Value flatIdx =
rewriter.createOrFold<arith::AddIOp>(loc, linearizedOffsets, index);
auto delinOp = affine::AffineDelinearizeIndexOp::create(
rewriter, loc, flatIdx, baseShape, /*hasOuterBound=*/true);
for (int64_t d = 0, rank = loadOffsets.size(); d < rank; ++d)
loadOffsets[d] = delinOp.getResult(d);
} else {
loadOffsets.back() =
rewriter.createOrFold<arith::AddIOp>(loc, lastLoadOffset, index);
}

The useDelinearization branch is taken unconditionally for all
rank > 1 memrefs, even in the case that the old branch handled
correctly: contiguous memref with unit innermost stride, where only the
innermost dim is offset by the gather index. The correctness fix that
motivated the PR concerns strided / non-unit-innermost-stride memrefs
and cases where gatherIdx could carry across dims; neither concern
arises in the contiguous rank-N, unit-innermost-stride case.
Guard useDelinearization on something stronger than
memType.getRank() > 1. Conservative options, in increasing order of
complexity:
- Only take the delinearization path when the memref is not a plain identity-layout memref with unit innermost stride.
- Additionally, only take it when the innermost dim's size is smaller than the maximum possible offsets.back() + max(gatherIdx) that the analysis can prove, i.e. when carry across dims is actually possible.
- Keep the old single-addi path by default, and surface the delinearizing path behind a pattern option, used by clients that actually need it.
For the contiguous / unit-innermost-stride case (which is what
tensor-backed IR vectorised from tensor.extract produces in IREE),
option 1 is sufficient to close the regression and preserves the
correctness fix for the strided case.
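
The shape of the option 1 guard can be sketched as a standalone C++ model. This is not the MLIR API: `MemRefModel`, `hasIdentityRowMajorLayout` and `useDelinearization` are hypothetical stand-ins, and the sketch assumes fully static shapes and strides expressed in elements. A real patch would query the equivalent facts from the memref type instead.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the layout facts of a memref type.
struct MemRefModel {
  std::vector<int64_t> shape;   // static dim sizes
  std::vector<int64_t> strides; // per-dim strides, in elements
};

// True when the strides are exactly the row-major strides implied by the
// shape, which also implies a unit innermost stride.
bool hasIdentityRowMajorLayout(const MemRefModel &m) {
  if (m.shape.size() != m.strides.size())
    return false;
  int64_t expected = 1;
  for (std::size_t d = m.shape.size(); d-- > 0;) {
    if (m.strides[d] != expected)
      return false;
    expected *= m.shape[d];
  }
  return true;
}

// Option 1: take the delinearization path only for rank > 1 memrefs that
// are not plain identity-layout memrefs.
bool useDelinearization(const MemRefModel &m) {
  return m.shape.size() > 1 && !hasIdentityRowMajorLayout(m);
}

// Illustrative check: the reproducer's memref<50x40x40xi8> keeps the
// single-addi path, while a strided view of the same shape does not.
void checkGuard() {
  assert(!useDelinearization({{50, 40, 40}, {1600, 40, 1}}));
  assert(useDelinearization({{50, 40, 40}, {3200, 80, 2}}));
}
```

Under this model the contiguous case from the reproducer falls back to the single add, and any non-identity stride pattern still goes through linearize/delinearize, so the strided fix is retained.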
- gather_input.mlir: the reproducer input above.
- gather_before.mlir: full lowered IR with #184706 reverted.
- gather_after.mlir: full lowered IR with #184706 applied.
- gather_after_lowered.mlir: same but with -lower-affine applied, showing the arithmetic that actually reaches the backend.