Generalize FoldMaskedTransferRAW and add FoldTransferReadOfEmptyTensor (#24301)
* Generalizes FoldMaskedTransferRAW (renamed to FoldTransferRAW) from the
(masked write, masked read) case to also cover (unmasked, unmasked),
(unmasked, masked), and (masked, unmasked); see the sketches after this list.
* Adds a pattern to fold transfer_read(tensor.empty) -> ub.poison (also
sketched below).
* Together, these allow intermediate index tensors introduced by
vectorization to be folded away.
* The test pipeline_vector_distribute_reduction_gfx942.mlir needed to be
corrected. The empty tensors are now folded, but the test was wrong:
online_attention should not have had empty tensors as operands in the first
place. All passes that create online_attention fill the operands with either
0 or -inf, so we do the same here.
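
For illustration, two minimal before/after sketches of the new folds, adapted
from the added lit tests (%t, %val, %mask, %cst, and %c0 are placeholder SSA
values):

  // Empty-tensor fold: an in-bounds read of tensor.empty has unspecified
  // contents, so it folds to poison.
  %e = tensor.empty() : tensor<128xf16>
  %r = vector.transfer_read %e[%c0], %cst {in_bounds = [true]}
    : tensor<128xf16>, vector<128xf16>
  // ==> becomes:
  %r = ub.poison : vector<128xf16>

  // Generalized RAW fold, (unmasked write, masked read) case: the read is
  // rewritten as a select between the written vector and the broadcast pad,
  // so no roundtrip through the tensor remains.
  %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
    : vector<128xf16>, tensor<128xf16>
  %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
    : tensor<128xf16>, vector<128xf16>
  // ==> becomes:
  %pad = vector.broadcast %cst : f16 to vector<128xf16>
  %r = arith.select %mask, %val, %pad : vector<128xi1>, vector<128xf16>
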
Fixes #24294
----
Assisted-By: Claude Opus 4.6
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
index 0eabebe..ca3bf4a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
@@ -256,77 +256,181 @@
}
};
-/// Folds IR resembling:
-/// ```
-/// %20 = vector.transfer_write %19, %16[%c0], %17 {in_bounds = [true]}
-// : vector<128xf16>, tensor<?xf16>
-// %21 = vector.transfer_read %20[%c0], %cst_2, %17
-/// : tensor<?xf16>, vector<128xf16>
-/// ```
-/// into a simpler masked vector.transfer_read.
+/// Folds a transfer_read that reads from the result of a transfer_write on
+/// the same region (Read-After-Write) into arithmetic on the written value,
+/// the original tensor, the masks, and the read's padding.
+///
+/// The general semantics are:
+///
+/// written_tensor[i] = wMask[i] ? valToStore[i] : original[i]
+/// result[i] = rMask[i] ? written_tensor[i] : rPad
+///
+/// Which gives:
+/// result = select(rMask, select(wMask, valToStore, original),
+/// broadcast(rPad))
+///
+/// Special cases avoid emitting unnecessary IR:
+/// - No wMask (unmasked write): wMask is implicitly all-true, inner select
+/// collapses to valToStore.
+/// - No rMask (unmasked read): rMask is implicitly all-true, outer select
+/// collapses away.
+/// - wMask == rMask: the original tensor is never needed (anywhere rMask is
+/// true, wMask is also true), so the inner select collapses to valToStore.
+///
/// After bufferization, this generally removes the need for materializing the
/// write to memory.
// TODO: Consider upstreaming
-struct FoldMaskedTransferRAW : OpRewritePattern<vector::TransferReadOp> {
+struct FoldTransferRAW : OpRewritePattern<vector::TransferReadOp> {
using Base::Base;
- LogicalResult matchAndRewrite(vector::TransferReadOp op,
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
- // Fail to match if the read doesn't have pure tensor semantics.
- if (!op.hasPureTensorSemantics()) {
+ if (!readOp.hasPureTensorSemantics()) {
return failure();
}
- // Try to get the producing write op.
auto writeOp = dyn_cast_if_present<vector::TransferWriteOp>(
- op.getBase().getDefiningOp());
- // Fail to match if the write doesn't have pure tensor semantics.
+ readOp.getBase().getDefiningOp());
if (!writeOp || !writeOp.hasPureTensorSemantics()) {
return failure();
}
Value valToStore = writeOp.getValueToStore();
- // Fail to match if the in/out types are different
- if (valToStore.getType() != op.getType()) {
+ if (valToStore.getType() != readOp.getType()) {
return failure();
}
// Work only with trivial or equal indices.
- if ((llvm::any_of(op.getIndices(),
+ if ((llvm::any_of(readOp.getIndices(),
[](Value v) { return !isZeroInteger(v); }) ||
llvm::any_of(writeOp.getIndices(),
[](Value v) { return !isZeroInteger(v); })) &&
- (op.getIndices() != writeOp.getIndices())) {
+ (readOp.getIndices() != writeOp.getIndices())) {
return failure();
}
- // Work only with minor identity mappings.
- if (!op.getPermutationMap().isMinorIdentity() ||
+ if (!readOp.getPermutationMap().isMinorIdentity() ||
!writeOp.getPermutationMap().isMinorIdentity()) {
return failure();
}
TypedValue<VectorType> wMask = writeOp.getMask();
- Value rPad = op.getPadding();
+ TypedValue<VectorType> rMask = readOp.getMask();
- // Match only if the write and read op are masked and have the same mask.
- if (!wMask || (wMask != op.getMask())) {
+ // Build the inner value: select(wMask, valToStore, original).
+ // When wMask is absent (unmasked write) or wMask == rMask (original is
+ // never accessed), this simplifies to just valToStore.
+ Value inner = valToStore;
+ bool needsOriginal = wMask && wMask != rMask;
+ if (needsOriginal) {
+ Value originalRead = vector::TransferReadOp::create(
+ rewriter, readOp.getLoc(), readOp.getType(), writeOp.getBase(),
+ readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
+ /*mask=*/Value(), readOp.getInBoundsAttr());
+ inner = arith::SelectOp::create(rewriter, readOp.getLoc(), wMask,
+ valToStore, originalRead);
+ }
+
+ if (!rMask) {
+ rewriter.replaceOp(readOp, inner);
+ return success();
+ }
+
+ // Build the outer value: select(rMask, inner, broadcast(rPad)).
+ // When rMask is absent (unmasked read), the result is just inner.
+ Value rPad = readOp.getPadding();
+ assert(!isa<VectorType>(rPad.getType()) &&
+ "masked transfers on vector element types are not supported; see "
+ "verifyTransferOp in upstream MLIR VectorOps.cpp");
+ Value padVal = vector::BroadcastOp::create(rewriter, rPad.getLoc(),
+ valToStore.getType(), rPad);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(readOp, rMask, inner, padVal);
+ return success();
+ }
+};
+
+/// Folds transfer_read(tensor.empty).
+///
+/// Since tensor.empty has unspecified contents, reading from it produces
+/// an unspecified value, which is exactly the semantics of ub.poison.
+/// Out-of-bounds lanes, however, take the padding value.
+///
+/// Case 1 — fully in-bounds, no mask:
+/// %e = tensor.empty() : tensor<128xf16>
+/// %r = vector.transfer_read %e[%c0], %pad {in_bounds = [true]}
+/// ->
+/// %r = ub.poison : vector<128xf16>
+///
+/// Case 2 — fully in-bounds, masked:
+/// %e = tensor.empty() : tensor<128xf16>
+/// %r = vector.transfer_read %e[%c0], %pad, %mask {in_bounds = [true]}
+/// ->
+/// %poison = ub.poison : vector<128xf16>
+/// %bcast = vector.broadcast %pad : f16 to vector<128xf16>
+/// %r = arith.select %mask, %poison, %bcast
+///
+/// Case 3 — not fully in-bounds, no mask:
+/// %e = tensor.empty() : tensor<100xf16>
+/// %r = vector.transfer_read %e[%c0], %pad
+/// : tensor<100xf16>, vector<128xf16>
+/// ->
+/// %r = vector.broadcast %pad : f16 to vector<128xf16>
+///
+/// Case 4 — not fully in-bounds, masked:
+/// %e = tensor.empty() : tensor<100xf16>
+/// %r = vector.transfer_read %e[%c0], %pad, %mask
+/// : tensor<100xf16>, vector<128xf16>
+/// ->
+/// %r = vector.broadcast %pad : f16 to vector<128xf16>
+struct FoldTransferReadOfEmptyTensor
+ : OpRewritePattern<vector::TransferReadOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp op,
+ PatternRewriter &rewriter) const override {
+ if (!op.hasPureTensorSemantics()) {
return failure();
}
- // NOTE[FoldMaskedTransferRAW]: since masking is not supported on shaped
- // types with vector element types (see `verifyTransferOp` in upstream MLIR
- // VectorOps.cpp), and the write op has a mask, it can be assumed `rPad`
- // never has a vector type. But for sanity add an assert in case things
- // change upstream.
- assert(!isa<VectorType>(rPad.getType()) &&
- "search `NOTE[FoldMaskedTransferRAW]` in "
- "GenericVectorization.cpp::FoldMaskedTransferRAW for information");
+ if (!op.getBase().getDefiningOp<tensor::EmptyOp>()) {
+ return failure();
+ }
- // Materialize the padding with a constant.
- auto padVal = vector::BroadcastOp::create(rewriter, rPad.getLoc(),
- valToStore.getType(), rPad);
- rewriter.replaceOpWithNewOp<arith::SelectOp>(op, wMask, valToStore, padVal);
+ if (!op.getPermutationMap().isMinorIdentity()) {
+ return failure();
+ }
+
+ bool fullyInBounds =
+ llvm::all_of(op.getInBoundsValues(), [](bool v) { return v; });
+ TypedValue<VectorType> mask = op.getMask();
+
+ if (mask && fullyInBounds) {
+ // Masked, fully in-bounds: mask-on lanes read unspecified contents
+ // (poison), mask-off lanes produce the padding value.
+ Value rPad = op.getPadding();
+ assert(!isa<VectorType>(rPad.getType()) &&
+ "masked transfers on vector element types are not supported; "
+ "see verifyTransferOp in upstream MLIR VectorOps.cpp");
+ Value poison = ub::PoisonOp::create(rewriter, op.getLoc(), op.getType());
+ Value padVal = vector::BroadcastOp::create(rewriter, rPad.getLoc(),
+ op.getType(), rPad);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, mask, poison, padVal);
+ return success();
+ }
+
+ if (!mask && fullyInBounds) {
+ // Unmasked, fully in-bounds: every lane reads unspecified contents.
+ rewriter.replaceOp(
+ op, ub::PoisonOp::create(rewriter, op.getLoc(), op.getType()));
+ return success();
+ }
+
+ // Not fully in-bounds (with or without mask): out-of-bounds lanes
+ // produce pad, and in-bounds lanes read unspecified contents from
+ // tensor.empty, so we may choose pad for those too.
+ Value rPad = op.getPadding();
+ rewriter.replaceOp(op, vector::BroadcastOp::create(rewriter, rPad.getLoc(),
+ op.getType(), rPad));
return success();
}
};
@@ -391,7 +495,7 @@
// Apply masked transfer_write + transfer_read folding to avoid spurious
// (future) roundtrips to memory.
// TODO: consider upstreaming.
- patterns.add<FoldMaskedTransferRAW>(context);
+ patterns.add<FoldTransferRAW, FoldTransferReadOfEmptyTensor>(context);
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
index a3d36fb..00524bd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
@@ -364,3 +364,358 @@
// CHECK: vector.transpose
// CHECK: vector.transfer_write
// CHECK: }
+
+// -----
+
+// Test for FoldTransferRAW.
+// Both write and read are masked with the same mask: replace with select(mask, val, broadcast(pad)).
+func.func @fold_masked_transfer_raw_both_masked(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0], %mask {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+
+// CHECK-LABEL: func.func @fold_masked_transfer_raw_both_masked
+// CHECK-SAME: %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[CST_1]], %[[CST_0]]
+// CHECK: return %[[SEL]]
+
+// -----
+
+// Test for FoldTransferRAW.
+// Masked write, unmasked read: replace with select(wMask, val, read(original_tensor)).
+func.func @fold_masked_transfer_raw_masked_write_unmasked_read(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0], %mask {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_masked_transfer_raw_masked_write_unmasked_read
+// CHECK-SAME: %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[T]]{{.*}}, %[[CST]] {in_bounds = [true]}
+// CHECK-SAME: : tensor<128xf16>, vector<128xf16>
+// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[VAL]], %[[READ]]
+// CHECK: return %[[SEL]]
+
+// -----
+
+// Test for FoldTransferRAW.
+// Both unmasked: the read is directly replaced by the written value.
+func.func @fold_masked_transfer_raw_both_unmasked(%t: tensor<128xf16>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_masked_transfer_raw_both_unmasked
+// CHECK-DAG: %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: return %[[VAL]]
+
+// -----
+
+// Test for FoldTransferRAW.
+// Unmasked write, masked read: replace with select(rMask, val, broadcast(pad)).
+func.func @fold_masked_transfer_raw_unmasked_write_masked_read(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %val = arith.constant dense<1.0> : vector<128xf16>
+ %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+ : vector<128xf16>, tensor<128xf16>
+ %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+
+// CHECK-LABEL: func.func @fold_masked_transfer_raw_unmasked_write_masked_read
+// CHECK-SAME: %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-DAG: %[[PAD:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]]
+// CHECK-DAG: %[[VAL]]
+// CHECK-DAG: %[[PAD]]
+// CHECK-SAME: vector<128xi1>, vector<128xf16>
+// CHECK: return %[[RES]]
+
+// -----
+
+// transfer_read from a memref (not tensor semantics): pattern must not fire.
+func.func @negative_read_empty_not_tensor_semantics(%m: memref<128xf16>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %r = vector.transfer_read %m[%c0], %cst {in_bounds = [true]}
+ : memref<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @negative_read_empty_not_tensor_semantics
+// CHECK: vector.transfer_read
+
+// -----
+
+// transfer_read from a regular tensor (not tensor.empty): pattern must not fire.
+func.func @negative_read_not_empty_tensor(%t: tensor<128xf16>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %r = vector.transfer_read %t[%c0], %cst {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @negative_read_not_empty_tensor
+// CHECK: vector.transfer_read
+
+// -----
+
+// transfer_read from tensor.empty with a transposing permutation map: bail.
+func.func @negative_read_empty_non_identity_map() -> vector<64x128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %e = tensor.empty() : tensor<128x64xf16>
+ %r = vector.transfer_read %e[%c0, %c0], %cst
+ {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+ : tensor<128x64xf16>, vector<64x128xf16>
+ return %r : vector<64x128xf16>
+}
+// CHECK-LABEL: func.func @negative_read_empty_non_identity_map
+// CHECK: tensor.empty
+// CHECK: vector.transfer_read
+
+// -----
+
+// Unmasked, in-bounds read from tensor.empty -> ub.poison.
+func.func @fold_read_empty_unmasked_inbounds() -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %e = tensor.empty() : tensor<128xf16>
+ %r = vector.transfer_read %e[%c0], %cst {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_unmasked_inbounds
+// CHECK-NOT: tensor.empty
+// CHECK-NOT: vector.transfer_read
+// CHECK: %[[POISON:.*]] = ub.poison : vector<128xf16>
+// CHECK: return %[[POISON]]
+
+// -----
+
+// Unmasked, out-of-bounds read from tensor.empty -> broadcast of the pad value.
+func.func @fold_read_empty_unmasked_outofbounds() -> vector<256xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %e = tensor.empty() : tensor<128xf16>
+ %r = vector.transfer_read %e[%c0], %cst
+ : tensor<128xf16>, vector<256xf16>
+ return %r : vector<256xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_unmasked_outofbounds
+// CHECK-NOT: tensor.empty
+// CHECK-NOT: vector.transfer_read
+// CHECK: %[[PAD:.+]] = arith.constant dense<0.000000e+00> : vector<256xf16>
+// CHECK: return %[[PAD]]
+
+// -----
+
+// Masked read from tensor.empty where padding is ub.poison -> just ub.poison.
+func.func @fold_read_empty_masked_poison_pad(%mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %pad = ub.poison : f16
+ %e = tensor.empty() : tensor<128xf16>
+ %r = vector.transfer_read %e[%c0], %pad, %mask {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_masked_poison_pad
+// CHECK-NOT: tensor.empty
+// CHECK-NOT: vector.transfer_read
+// CHECK-NOT: arith.select
+// CHECK: %[[POISON:.*]] = ub.poison : vector<128xf16>
+// CHECK: return %[[POISON]]
+
+// -----
+
+// Masked read from tensor.empty with a concrete pad value -> select(mask, poison, broadcast(pad)).
+// Followed by: select cond, poison, X -> X
+func.func @fold_read_empty_masked_real_pad(%mask: vector<128xi1>) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %e = tensor.empty() : tensor<128xf16>
+ %r = vector.transfer_read %e[%c0], %cst, %mask {in_bounds = [true]}
+ : tensor<128xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_masked_real_pad
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK: return %[[CST]]
+
+// -----
+
+// Unmasked read from a dynamically-shaped tensor.empty -> ub.poison.
+func.func @fold_read_empty_dynamic_unmasked(%sz: index) -> vector<128xf16> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %e = tensor.empty(%sz) : tensor<?xf16>
+ %r = vector.transfer_read %e[%c0], %cst {in_bounds = [true]}
+ : tensor<?xf16>, vector<128xf16>
+ return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_dynamic_unmasked
+// CHECK-NOT: tensor.empty
+// CHECK-NOT: vector.transfer_read
+// CHECK: %[[POISON:.*]] = ub.poison : vector<128xf16>
+// CHECK: return %[[POISON]]
+
+// -----
+
+// Multiple chained gathers: the index vectors computed for the first two
+// gathers (and the clamp ops derived from their results) must be reused
+// directly as vector SSA values by subsequent gathers. Vectorization may
+// introduce tensor.empty<...xindex> intermediaries with write-read chains
+// for materialized index vectors; these are cleaned up by the
+// optimize-tensor-insert-extract-slices pass that follows.
+
+#map = affine_map<(d0, d1, d2)[s0, s1, s2] -> (s0, s1, s2, 0)>
+#map1 = affine_map<(d0, d1, d2)[s0, s1, s2] -> ()>
+#map2 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2)>
+#map4 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (s0, s1, s2)>
+module {
+ func.func @three_gathers_index_materialization(%arg0: tensor<1x8x?xf32>, %arg1: tensor<1x8xf32>, %arg2: tensor<1x8xf32>, %arg3: tensor<1x8x?xf32>, %arg4: tensor<50x32x25x2xi32>, %arg5: tensor<50x40x40xi8>, %arg6: index, %arg7: index, %arg8: index, %arg9: index) -> tensor<1x8x?xf32> {
+ %cst = arith.constant dense<0> : vector<1x8x8xi8>
+ %0 = ub.poison : i8
+ %cst_0 = arith.constant dense<39> : vector<1x8x8xindex>
+ %cst_1 = arith.constant dense<0> : vector<1x8x8xindex>
+ %1 = ub.poison : i32
+ %cst_2 = arith.constant dense<0.000000e+00> : vector<1x8x8xf32>
+ %2 = ub.poison : index
+ %3 = ub.poison : f32
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %4 = tensor.empty() : tensor<1x8x8xf32>
+ %5 = tensor.empty() : tensor<1x8x8xindex>
+ %6 = tensor.empty() : tensor<1x8x8xindex>
+ %7 = tensor.empty() : tensor<1x8x8xindex>
+ %dim = tensor.dim %arg0, %c2 : tensor<1x8x?xf32>
+ %8 = vector.transfer_read %arg1[%c0, %c0], %3 {in_bounds = [true, true]} : tensor<1x8xf32>, vector<1x8xf32>
+ %9 = vector.transfer_read %arg2[%c0, %c0], %3 {in_bounds = [true, true]} : tensor<1x8xf32>, vector<1x8xf32>
+ %10 = vector.create_mask %c1, %c8, %dim : vector<1x8x8xi1>
+ %11 = vector.transfer_read %arg0[%c0, %c0, %c0], %3, %10 {in_bounds = [true, true, true]} : tensor<1x8x?xf32>, vector<1x8x8xf32>
+ %12 = arith.divf %8, %9 : vector<1x8xf32>
+ %13 = vector.broadcast %12 : vector<1x8xf32> to vector<8x1x8xf32>
+ %14 = vector.transpose %13, [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32>
+ %15 = arith.cmpf une, %11, %cst_2 : vector<1x8x8xf32>
+ %16 = arith.select %15, %14, %cst_2 : vector<1x8x8xi1>, vector<1x8x8xf32>
+ %17 = arith.addi %arg6, %arg7 : index
+ %18 = vector.step : vector<8xindex>
+ %19 = vector.broadcast %18 : vector<8xindex> to vector<1x8x8xindex>
+ %20 = vector.transpose %19, [0, 2, 1] : vector<1x8x8xindex> to vector<1x8x8xindex>
+ %21 = vector.broadcast %arg8 : index to vector<1x8x8xindex>
+ %22 = arith.addi %21, %20 : vector<1x8x8xindex>
+ %23 = vector.step : vector<8xindex>
+ %24 = vector.broadcast %arg9 : index to vector<8xindex>
+ %25 = arith.addi %24, %23 : vector<8xindex>
+ %26 = vector.transfer_write %16, %4[%c0, %c0, %c0], %10 {in_bounds = [true, true, true]} : vector<1x8x8xf32>, tensor<1x8x8xf32>
+ %27 = vector.broadcast %17 : index to vector<1x8x8xindex>
+ %28 = vector.transfer_write %27, %5[%c0, %c0, %c0], %10 {in_bounds = [true, true, true]} : vector<1x8x8xindex>, tensor<1x8x8xindex>
+ %29 = vector.transfer_write %22, %6[%c0, %c0, %c0], %10 {in_bounds = [true, true, true]} : vector<1x8x8xindex>, tensor<1x8x8xindex>
+ %30 = vector.broadcast %25 : vector<8xindex> to vector<1x8x8xindex>
+ %31 = vector.transfer_write %30, %7[%c0, %c0, %c0], %10 {in_bounds = [true, true, true]} : vector<1x8x8xindex>, tensor<1x8x8xindex>
+ %32 = iree_vector_ext.transfer_gather %arg4[%c0, %c0, %c0, %c0] [%17, %22, %25 : index, vector<1x8x8xindex>, vector<8xindex>], %1 {indexing_maps = [#map, #map1, #map2, #map3]} : tensor<50x32x25x2xi32>, vector<1x8x8xi32>
+ %33 = arith.index_cast %32 : vector<1x8x8xi32> to vector<1x8x8xindex>
+ %34 = vector.transfer_read %28[%c0, %c0, %c0], %2 {in_bounds = [true, true, true]} : tensor<1x8x8xindex>, vector<1x8x8xindex>
+ %35 = vector.transfer_read %29[%c0, %c0, %c0], %2 {in_bounds = [true, true, true]} : tensor<1x8x8xindex>, vector<1x8x8xindex>
+ %36 = vector.transfer_read %31[%c0, %c0, %c0], %2 {in_bounds = [true, true, true]} : tensor<1x8x8xindex>, vector<1x8x8xindex>
+ %37 = iree_vector_ext.transfer_gather %arg4[%c0, %c0, %c0, %c1] [%34, %35, %36 : vector<1x8x8xindex>, vector<1x8x8xindex>, vector<1x8x8xindex>], %1 {indexing_maps = [#map, #map2, #map2, #map2]} : tensor<50x32x25x2xi32>, vector<1x8x8xi32>
+ %38 = arith.index_cast %37 : vector<1x8x8xi32> to vector<1x8x8xindex>
+ %39 = arith.maxsi %33, %cst_1 : vector<1x8x8xindex>
+ %40 = arith.minui %39, %cst_0 : vector<1x8x8xindex>
+ %41 = arith.maxsi %38, %cst_1 : vector<1x8x8xindex>
+ %42 = arith.minui %41, %cst_0 : vector<1x8x8xindex>
+ %43 = vector.transfer_read %28[%c0, %c0, %c0], %2 {in_bounds = [true, true, true]} : tensor<1x8x8xindex>, vector<1x8x8xindex>
+ %44 = iree_vector_ext.transfer_gather %arg5[%c0, %c0, %c0] [%43, %40, %42 : vector<1x8x8xindex>, vector<1x8x8xindex>, vector<1x8x8xindex>], %0 {indexing_maps = [#map4, #map2, #map2, #map2]} : tensor<50x40x40xi8>, vector<1x8x8xi8>
+ %45 = vector.transfer_read %26[%c0, %c0, %c0], %3 {in_bounds = [true, true, true]} : tensor<1x8x8xf32>, vector<1x8x8xf32>
+ %46 = arith.cmpi ugt, %44, %cst : vector<1x8x8xi8>
+ %47 = arith.select %46, %45, %cst_2 : vector<1x8x8xi1>, vector<1x8x8xf32>
+ %48 = vector.transfer_write %47, %arg3[%c0, %c0, %c0] {in_bounds = [true, true, false]} : vector<1x8x8xf32>, tensor<1x8x?xf32>
+ return %48 : tensor<1x8x?xf32>
+ }
+}
+
+// Verify that the three transfer_gather ops remain. The index vectors computed
+// for the first two gathers, and the clamps applied to their results, feed the
+// third gather directly as vector SSA values — no tensor.empty<...xindex>
+// intermediaries or write-read chains are left.
+//
+// CHECK-LABEL: func.func @three_gathers_index_materialization
+// CHECK-SAME: %[[IN0:[a-zA-Z0-9]+]]: tensor<1x8x?xf32>
+// CHECK-SAME: %[[IN1:[a-zA-Z0-9]+]]: tensor<1x8xf32>
+// CHECK-SAME: %[[IN2:[a-zA-Z0-9]+]]: tensor<1x8xf32>
+// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: tensor<1x8x?xf32>
+// CHECK-SAME: %[[TABLE:[a-zA-Z0-9]+]]: tensor<50x32x25x2xi32>
+// CHECK-SAME: %[[LUT:[a-zA-Z0-9]+]]: tensor<50x40x40xi8>
+//
+// No index-typed tensors should appear anywhere in the output.
+// CHECK-NOT: tensor<{{.*}}xindex>
+//
+// Gather #1 from %[[TABLE]] — index vectors consumed directly.
+// CHECK: %[[G1:.+]] = iree_vector_ext.transfer_gather %[[TABLE]]
+// CHECK-SAME: : tensor<50x32x25x2xi32>, vector<1x8x8xi32>
+// CHECK-NOT: tensor<{{.*}}xindex>
+// CHECK: %[[G1_IDX:.+]] = arith.index_cast %[[G1]] : vector<1x8x8xi32> to vector<1x8x8xindex>
+// CHECK-NOT: tensor<{{.*}}xindex>
+//
+// Gather #2 from %[[TABLE]] — reuses the same index vectors as #1.
+// CHECK: %[[G2:.+]] = iree_vector_ext.transfer_gather %[[TABLE]]
+// CHECK-SAME: : tensor<50x32x25x2xi32>, vector<1x8x8xi32>
+// CHECK-NOT: tensor<{{.*}}xindex>
+// CHECK: %[[G2_IDX:.+]] = arith.index_cast %[[G2]] : vector<1x8x8xi32> to vector<1x8x8xindex>
+// CHECK-NOT: tensor<{{.*}}xindex>
+//
+// Clamp results of gather #1 and #2 — pure vector ops, no tensor roundtrip.
+// G1_IDX -> maxsi -> minui = CLAMP_A, G2_IDX -> maxsi -> minui = CLAMP_B.
+// CHECK: %[[G1_MAX:.+]] = arith.maxsi %[[G1_IDX]], {{.*}} : vector<1x8x8xindex>
+// CHECK-NOT: tensor<{{.*}}xindex>
+// CHECK: %[[CLAMP_A:.+]] = arith.minui %[[G1_MAX]], {{.*}} : vector<1x8x8xindex>
+// CHECK-NOT: tensor<{{.*}}xindex>
+// CHECK: %[[G2_MAX:.+]] = arith.maxsi %[[G2_IDX]], {{.*}} : vector<1x8x8xindex>
+// CHECK-NOT: tensor<{{.*}}xindex>
+// CHECK: %[[CLAMP_B:.+]] = arith.minui %[[G2_MAX]], {{.*}} : vector<1x8x8xindex>
+// CHECK-NOT: tensor<{{.*}}xindex>
+//
+// Gather #3 from %[[LUT]] — takes the clamped results directly as index vectors.
+// CHECK: %[[G3:.+]] = iree_vector_ext.transfer_gather %[[LUT]]
+// CHECK-SAME: [{{.*}}, %[[CLAMP_A]], %[[CLAMP_B]] : {{.*}}]
+// CHECK-SAME: : tensor<50x40x40xi8>, vector<1x8x8xi8>
+// CHECK-NOT: tensor<{{.*}}xindex>
+//
+// Final select + write.
+// CHECK: %[[GATE:.+]] = arith.cmpi ugt, %[[G3]]
+// CHECK-NOT: tensor<{{.*}}xindex>
+// CHECK: %[[RES:.+]] = arith.select %[[GATE]]
+// CHECK-NOT: tensor<{{.*}}xindex>
+// CHECK: vector.transfer_write %[[RES]], %[[OUT]]
+// CHECK-NOT: tensor<{{.*}}xindex>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_reduction_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_reduction_gfx942.mlir
index 4b04ef0..c84f034 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_reduction_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_reduction_gfx942.mlir
@@ -289,6 +289,11 @@
%6 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
%7 = tensor.empty() : tensor<20x1x64xf32>
%8 = tensor.empty() : tensor<20x1xf32>
+ %cst_zero = arith.constant 0.000000e+00 : f32
+ %cst_neg_inf = arith.constant 0xFF800000 : f32
+ %acc_fill = linalg.fill ins(%cst_zero : f32) outs(%7: tensor<20x1x64xf32>) -> tensor<20x1x64xf32>
+ %max_fill = linalg.fill ins(%cst_neg_inf : f32) outs(%8: tensor<20x1xf32>) -> tensor<20x1xf32>
+ %sum_fill = linalg.fill ins(%cst_zero : f32) outs(%8: tensor<20x1xf32>) -> tensor<20x1xf32>
%9:3 = iree_linalg_ext.online_attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
@@ -301,7 +306,7 @@
qk_attrs = {lowering_config = #qk_config},
pv_attrs = {lowering_config = #pv_config}
}}
- ins(%4, %5, %6, %cst : tensor<20x1x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7, %8, %8 : tensor<20x1x64xf32>, tensor<20x1xf32>, tensor<20x1xf32>) {
+ ins(%4, %5, %6, %cst : tensor<20x1x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%acc_fill, %max_fill, %sum_fill : tensor<20x1x64xf32>, tensor<20x1xf32>, tensor<20x1xf32>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<20x1x64xf32>, tensor<20x1xf32>, tensor<20x1xf32>
@@ -464,6 +469,13 @@
%13 = tensor.empty(%6) : tensor<4x1x1x?x32xf16>
%14 = tensor.empty(%6) : tensor<4x?x1x32x128xf16>
%15 = tensor.empty() : tensor<4x1x1xf32>
+
+ %cst_acc = arith.constant 0.000000e+00 : f32
+ %cst_max = arith.constant 0xFF800000 : f32
+ %acc_fill = linalg.fill ins(%cst_acc : f32) outs(%12 : tensor<4x1x1x128xf32>) -> tensor<4x1x1x128xf32>
+ %max_fill = linalg.fill ins(%cst_max : f32) outs(%15 : tensor<4x1x1xf32>) -> tensor<4x1x1xf32>
+ %sum_fill = linalg.fill ins(%cst_acc : f32) outs(%15 : tensor<4x1x1xf32>) -> tensor<4x1x1xf32>
+
%16 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0, 0, 0, 0, 0, 0], sizes = [4096, 1, 1, 1, 32, 128], strides = [1, 1, 1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x1x2x1x32x128xf16>> -> tensor<4096x1x32x128xf16>
%17 = iree_linalg_ext.gather dimension_map = [0] ins(%16, %9 : tensor<4096x1x32x128xf16>, tensor<4x?xi64>) outs(%14 : tensor<4x?x1x32x128xf16>) -> tensor<4x?x1x32x128xf16>
%18 = iree_linalg_ext.gather dimension_map = [0] ins(%16, %10 : tensor<4096x1x32x128xf16>, tensor<4x?xi64>) outs(%14 : tensor<4x?x1x32x128xf16>) -> tensor<4x?x1x32x128xf16>
@@ -500,7 +512,7 @@
lowering_config = #attention_lowering_config
}
ins(%11, %17, %18, %cst, %19 : tensor<4x1x1x128xf16>, tensor<4x?x1x32x128xf16>, tensor<4x?x1x32x128xf16>, f16, tensor<4x1x1x?x32xf16>)
- outs(%12, %15, %15 : tensor<4x1x1x128xf32>, tensor<4x1x1xf32>, tensor<4x1x1xf32>) {
+ outs(%acc_fill, %max_fill, %sum_fill : tensor<4x1x1x128xf32>, tensor<4x1x1xf32>, tensor<4x1x1xf32>) {
^bb0(%arg0: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<4x1x1x128xf32>, tensor<4x1x1xf32>, tensor<4x1x1xf32>