Generalize FoldMaskedTransferRAW and add FoldTransferReadOfEmptyTensor (#24301)

* Generalizes FoldMaskedTransferRAW from (masked, masked) to (unmasked,
unmasked), (unmasked, masked), (masked, unmasked).
* Adds pattern to fold transfer_read(tensor.empty) -> ub.poison
* This allows intermediary index tensors to be folded after
vectorization.
* The test pipeline_vector_distribute_reduction_gfx942.mlir needed to be
corrected. The empty tensors are now folded, but the test was wrong,
online_attention should not have had empty tensors as operands to begin
with. All passes that create online_attention fill the operands with
either 0 or -1. So we do that here as well.

Fixes #24294

----

Assisted-By: Claude Opus 4.6
diff --git a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
index 0eabebe..ca3bf4a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/OptimizeTensorInsertExtractSlices.cpp
@@ -256,77 +256,181 @@
   }
 };
 
-/// Folds IR resembling:
-/// ```
-///   %20 = vector.transfer_write %19, %16[%c0], %17 {in_bounds = [true]}
-//      : vector<128xf16>, tensor<?xf16>
-//    %21 = vector.transfer_read %20[%c0], %cst_2, %17
-///     : tensor<?xf16>, vector<128xf16>
-/// ```
-/// into a simpler masked vector.transfer_read.
+/// Folds a transfer_read that reads from the result of a transfer_write on
+/// the same region (Read-After-Write) into arithmetic on the written value,
+/// the original tensor, the masks, and the read's padding.
+///
+/// The general semantics are:
+///
+///   written_tensor[i] = wMask[i] ? valToStore[i] : original[i]
+///   result[i]         = rMask[i] ? written_tensor[i] : rPad
+///
+/// Which gives:
+///   result = select(rMask, select(wMask, valToStore, original),
+///   broadcast(rPad))
+///
+/// Special cases avoid emitting unnecessary IR:
+///   - No wMask (unmasked write): wMask is implicitly all-true, inner select
+///     collapses to valToStore.
+///   - No rMask (unmasked read): rMask is implicitly all-true, outer select
+///     collapses away.
+///   - wMask == rMask: the original tensor is never needed (anywhere rMask is
+///     true, wMask is also true), so the inner select collapses to valToStore.
+///
 /// After bufferization, this generally removes the need for materializing the
 /// write to memory.
 // TODO: Consider upstreaming
-struct FoldMaskedTransferRAW : OpRewritePattern<vector::TransferReadOp> {
+struct FoldTransferRAW : OpRewritePattern<vector::TransferReadOp> {
   using Base::Base;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    // Fail to match if the read doesn't have pure tensor semantics.
-    if (!op.hasPureTensorSemantics()) {
+    if (!readOp.hasPureTensorSemantics()) {
       return failure();
     }
 
-    // Try to get the producing write op.
     auto writeOp = dyn_cast_if_present<vector::TransferWriteOp>(
-        op.getBase().getDefiningOp());
-    // Fail to match if the write doesn't have pure tensor semantics.
+        readOp.getBase().getDefiningOp());
     if (!writeOp || !writeOp.hasPureTensorSemantics()) {
       return failure();
     }
 
     Value valToStore = writeOp.getValueToStore();
-    // Fail to match if the in/out types are different
-    if (valToStore.getType() != op.getType()) {
+    if (valToStore.getType() != readOp.getType()) {
       return failure();
     }
 
     // Work only with trivial or equal indices.
-    if ((llvm::any_of(op.getIndices(),
+    if ((llvm::any_of(readOp.getIndices(),
                       [](Value v) { return !isZeroInteger(v); }) ||
          llvm::any_of(writeOp.getIndices(),
                       [](Value v) { return !isZeroInteger(v); })) &&
-        (op.getIndices() != writeOp.getIndices())) {
+        (readOp.getIndices() != writeOp.getIndices())) {
       return failure();
     }
 
-    // Work only with minor identity mappings.
-    if (!op.getPermutationMap().isMinorIdentity() ||
+    if (!readOp.getPermutationMap().isMinorIdentity() ||
         !writeOp.getPermutationMap().isMinorIdentity()) {
       return failure();
     }
 
     TypedValue<VectorType> wMask = writeOp.getMask();
-    Value rPad = op.getPadding();
+    TypedValue<VectorType> rMask = readOp.getMask();
 
-    // Match only if the write and read op are masked and have the same mask.
-    if (!wMask || (wMask != op.getMask())) {
+    // Build the inner value: select(wMask, valToStore, original).
+    // When wMask is absent (unmasked write) or wMask == rMask (original is
+    // never accessed), this simplifies to just valToStore.
+    Value inner = valToStore;
+    bool needsOriginal = wMask && wMask != rMask;
+    if (needsOriginal) {
+      Value originalRead = vector::TransferReadOp::create(
+          rewriter, readOp.getLoc(), readOp.getType(), writeOp.getBase(),
+          readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
+          /*mask=*/Value(), readOp.getInBoundsAttr());
+      inner = arith::SelectOp::create(rewriter, readOp.getLoc(), wMask,
+                                      valToStore, originalRead);
+    }
+
+    if (!rMask) {
+      rewriter.replaceOp(readOp, inner);
+      return success();
+    }
+
+    // Build the outer value: select(rMask, inner, broadcast(rPad)). The
+    // unmasked-read case (no rMask) was handled by the early return above.
+    Value rPad = readOp.getPadding();
+    assert(!isa<VectorType>(rPad.getType()) &&
+           "masked transfers on vector element types are not supported; see "
+           "verifyTransferOp in upstream MLIR VectorOps.cpp");
+    Value padVal = vector::BroadcastOp::create(rewriter, rPad.getLoc(),
+                                               valToStore.getType(), rPad);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(readOp, rMask, inner, padVal);
+    return success();
+  }
+};
+
+/// Folds transfer_read(tensor.empty).
+///
+/// Since tensor.empty has unspecified contents, reading from it produces
+/// an unspecified value, which is exactly the semantics of ub.poison.
+/// Out of bounds means that pad is used.
+///
+///   Case 1 — fully in-bounds, no mask:
+///     %e = tensor.empty() : tensor<128xf16>
+///     %r = vector.transfer_read %e[%c0], %pad {in_bounds = [true]}
+///   ->
+///     %r = ub.poison : vector<128xf16>
+///
+///   Case 2 — fully in-bounds, masked:
+///     %e = tensor.empty() : tensor<128xf16>
+///     %r = vector.transfer_read %e[%c0], %pad, %mask {in_bounds = [true]}
+///   ->
+///     %poison = ub.poison : vector<128xf16>
+///     %bcast  = vector.broadcast %pad : f16 to vector<128xf16>
+///     %r = arith.select %mask, %poison, %bcast
+///
+///   Case 3 — not fully in-bounds, no mask:
+///     %e = tensor.empty() : tensor<100xf16>
+///     %r = vector.transfer_read %e[%c0], %pad
+///                              : tensor<100xf16>, vector<128xf16>
+///   ->
+///     %r = vector.broadcast %pad : f16 to vector<128xf16>
+///
+///   Case 4 — not fully in-bounds, masked:
+///     %e = tensor.empty() : tensor<100xf16>
+///     %r = vector.transfer_read %e[%c0], %pad, %mask
+///                              : tensor<100xf16>, vector<128xf16>
+///   ->
+///     %r = vector.broadcast %pad : f16 to vector<128xf16>
+struct FoldTransferReadOfEmptyTensor
+    : OpRewritePattern<vector::TransferReadOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.hasPureTensorSemantics()) {
       return failure();
     }
 
-    // NOTE[FoldMaskedTransferRAW]: since masking is not supported on shaped
-    // types with vector element types (see `verifyTransferOp` in upstream MLIR
-    // VectorOps.cpp), and the write op has a mask, it can be assumed `rPad`
-    // never has a vector type. But for sanity add an assert in case things
-    // change upstream.
-    assert(!isa<VectorType>(rPad.getType()) &&
-           "search `NOTE[FoldMaskedTransferRAW]` in "
-           "GenericVectorization.cpp::FoldMaskedTransferRAW for information");
+    if (!op.getBase().getDefiningOp<tensor::EmptyOp>()) {
+      return failure();
+    }
 
-    // Materialize the padding with a constant.
-    auto padVal = vector::BroadcastOp::create(rewriter, rPad.getLoc(),
-                                              valToStore.getType(), rPad);
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, wMask, valToStore, padVal);
+    if (!op.getPermutationMap().isMinorIdentity()) {
+      return failure();
+    }
+
+    bool fullyInBounds =
+        llvm::all_of(op.getInBoundsValues(), [](bool v) { return v; });
+    TypedValue<VectorType> mask = op.getMask();
+
+    if (mask && fullyInBounds) {
+      // Masked, fully in-bounds: mask-on lanes read unspecified contents
+      // (poison), mask-off lanes produce the padding value.
+      Value rPad = op.getPadding();
+      assert(!isa<VectorType>(rPad.getType()) &&
+             "masked transfers on vector element types are not supported; "
+             "see verifyTransferOp in upstream MLIR VectorOps.cpp");
+      Value poison = ub::PoisonOp::create(rewriter, op.getLoc(), op.getType());
+      Value padVal = vector::BroadcastOp::create(rewriter, rPad.getLoc(),
+                                                 op.getType(), rPad);
+      rewriter.replaceOpWithNewOp<arith::SelectOp>(op, mask, poison, padVal);
+      return success();
+    }
+
+    if (!mask && fullyInBounds) {
+      // Unmasked, fully in-bounds: every lane reads unspecified contents.
+      rewriter.replaceOp(
+          op, ub::PoisonOp::create(rewriter, op.getLoc(), op.getType()));
+      return success();
+    }
+
+    // Not fully in-bounds (with or without mask): out-of-bounds lanes
+    // produce pad, and in-bounds lanes read unspecified contents from
+    // tensor.empty, so we may choose pad for those too.
+    Value rPad = op.getPadding();
+    rewriter.replaceOp(op, vector::BroadcastOp::create(rewriter, rPad.getLoc(),
+                                                       op.getType(), rPad));
     return success();
   }
 };
@@ -391,7 +495,7 @@
   // Apply masked transfer_write + transfer_read folding to avoid spurious
   // (future) roundtrips to memory.
   // TODO: consider upstreaming.
-  patterns.add<FoldMaskedTransferRAW>(context);
+  patterns.add<FoldTransferRAW, FoldTransferReadOfEmptyTensor>(context);
   if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
     return signalPassFailure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
index a3d36fb..00524bd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/optimize_tensor_insert_extract_slices.mlir
@@ -364,3 +364,358 @@
 // CHECK:         vector.transpose
 // CHECK:         vector.transfer_write
 // CHECK:       }
+
+// -----
+
+// Test for FoldTransferRAW.
+// Both write and read are masked with the same mask: replace with select(mask, val, broadcast(pad)).
+func.func @fold_masked_transfer_raw_both_masked(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0], %mask {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+
+// CHECK-LABEL: func.func @fold_masked_transfer_raw_both_masked
+// CHECK-SAME:    %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK-DAG:     %[[CST_1:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK:         %[[SEL:.*]] = arith.select %[[MASK]], %[[CST_1]], %[[CST_0]]
+// CHECK:         return %[[SEL]]
+
+// -----
+
+// Test for FoldTransferRAW.
+// Masked write, unmasked read: replace with select(wMask, val, read(original_tensor)).
+func.func @fold_masked_transfer_raw_masked_write_unmasked_read(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0], %mask {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_masked_transfer_raw_masked_write_unmasked_read
+// CHECK-SAME:    %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[CST:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG:     %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[T]]{{.*}}, %[[CST]] {in_bounds = [true]}
+// CHECK-SAME:      : tensor<128xf16>, vector<128xf16>
+// CHECK:         %[[SEL:.*]] = arith.select %[[MASK]], %[[VAL]], %[[READ]]
+// CHECK:         return %[[SEL]]
+
+// -----
+
+// Test for FoldTransferRAW.
+// Both unmasked: the read is directly replaced by the written value.
+func.func @fold_masked_transfer_raw_both_unmasked(%t: tensor<128xf16>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_masked_transfer_raw_both_unmasked
+// CHECK-DAG:     %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-NOT:     vector.transfer_write
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         return %[[VAL]]
+
+// -----
+
+// Test for FoldTransferRAW.
+// Unmasked write, masked read: re-read the original tensor with the read's mask.
+func.func @fold_masked_transfer_raw_unmasked_write_masked_read(%t: tensor<128xf16>, %mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %val = arith.constant dense<1.0> : vector<128xf16>
+  %w = vector.transfer_write %val, %t[%c0] {in_bounds = [true]}
+     : vector<128xf16>, tensor<128xf16>
+  %r = vector.transfer_read %w[%c0], %cst, %mask {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+
+// CHECK-LABEL: func.func @fold_masked_transfer_raw_unmasked_write_masked_read
+// CHECK-SAME:    %[[T:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[VAL:.*]] = arith.constant dense<1.000000e+00> : vector<128xf16>
+// CHECK-DAG:     %[[PAD:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK-NOT:     vector.transfer_write
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         %[[RES:.+]] = arith.select %[[MASK]]
+// CHECK-DAG:         %[[VAL]]
+// CHECK-DAG:         %[[PAD]]
+// CHECK-SAME:       vector<128xi1>, vector<128xf16>
+// CHECK:         return %[[RES]]
+
+// -----
+
+// transfer_read from a memref (not tensor semantics): pattern must not fire.
+func.func @negative_read_empty_not_tensor_semantics(%m: memref<128xf16>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %r = vector.transfer_read %m[%c0], %cst {in_bounds = [true]}
+     : memref<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @negative_read_empty_not_tensor_semantics
+// CHECK:         vector.transfer_read
+
+// -----
+
+// transfer_read from a regular tensor (not tensor.empty): pattern must not fire.
+func.func @negative_read_not_empty_tensor(%t: tensor<128xf16>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %r = vector.transfer_read %t[%c0], %cst {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @negative_read_not_empty_tensor
+// CHECK:         vector.transfer_read
+
+// -----
+
+// transfer_read from tensor.empty with a transposing permutation map: bail.
+func.func @negative_read_empty_non_identity_map() -> vector<64x128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %e = tensor.empty() : tensor<128x64xf16>
+  %r = vector.transfer_read %e[%c0, %c0], %cst
+     {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+     : tensor<128x64xf16>, vector<64x128xf16>
+  return %r : vector<64x128xf16>
+}
+// CHECK-LABEL: func.func @negative_read_empty_non_identity_map
+// CHECK:         tensor.empty
+// CHECK:         vector.transfer_read
+
+// -----
+
+// Unmasked, in-bounds read from tensor.empty -> ub.poison.
+func.func @fold_read_empty_unmasked_inbounds() -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %e = tensor.empty() : tensor<128xf16>
+  %r = vector.transfer_read %e[%c0], %cst {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_unmasked_inbounds
+// CHECK-NOT:     tensor.empty
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         %[[POISON:.*]] = ub.poison : vector<128xf16>
+// CHECK:         return %[[POISON]]
+
+// -----
+
+// Unmasked, out-of-bounds read from tensor.empty -> broadcast of the pad value.
+func.func @fold_read_empty_unmasked_outofbounds() -> vector<256xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %e = tensor.empty() : tensor<128xf16>
+  %r = vector.transfer_read %e[%c0], %cst
+     : tensor<128xf16>, vector<256xf16>
+  return %r : vector<256xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_unmasked_outofbounds
+// CHECK-NOT:     tensor.empty
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         %[[PAD:.+]] = arith.constant dense<0.000000e+00> : vector<256xf16>
+// CHECK:         return %[[PAD]]
+
+// -----
+
+// Masked read from tensor.empty where padding is ub.poison -> just ub.poison.
+func.func @fold_read_empty_masked_poison_pad(%mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %pad = ub.poison : f16
+  %e = tensor.empty() : tensor<128xf16>
+  %r = vector.transfer_read %e[%c0], %pad, %mask {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_masked_poison_pad
+// CHECK-NOT:     tensor.empty
+// CHECK-NOT:     vector.transfer_read
+// CHECK-NOT:     arith.select
+// CHECK:         %[[POISON:.*]] = ub.poison : vector<128xf16>
+// CHECK:         return %[[POISON]]
+
+// -----
+
+// Masked read from tensor.empty with a concrete pad value -> select(mask, poison, broadcast(pad)).
+// Followed by: select cond, poison, X -> X
+func.func @fold_read_empty_masked_real_pad(%mask: vector<128xi1>) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %e = tensor.empty() : tensor<128xf16>
+  %r = vector.transfer_read %e[%c0], %cst, %mask {in_bounds = [true]}
+     : tensor<128xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_masked_real_pad
+// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<128xf16>
+// CHECK:         return %[[CST]]
+
+// -----
+
+// Unmasked read from a dynamically-shaped tensor.empty -> ub.poison.
+func.func @fold_read_empty_dynamic_unmasked(%sz: index) -> vector<128xf16> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f16
+  %e = tensor.empty(%sz) : tensor<?xf16>
+  %r = vector.transfer_read %e[%c0], %cst {in_bounds = [true]}
+     : tensor<?xf16>, vector<128xf16>
+  return %r : vector<128xf16>
+}
+// CHECK-LABEL: func.func @fold_read_empty_dynamic_unmasked
+// CHECK-NOT:     tensor.empty
+// CHECK-NOT:     vector.transfer_read
+// CHECK:         %[[POISON:.*]] = ub.poison : vector<128xf16>
+// CHECK:         return %[[POISON]]
+
+// -----
+
+// Multiple chained gathers: the index vectors computed for the first two
+// gathers (and the clamp ops derived from their results) must be reused
+// directly as vector SSA values by subsequent gathers. Vectorization may
+// introduce tensor.empty<...xindex> intermediaries with write-read chains
+// for materialized index vectors; these are cleaned up by the
+// optimize-tensor-insert-extract-slices pass that follows.
+
+#map = affine_map<(d0, d1, d2)[s0, s1, s2] -> (s0, s1, s2, 0)>
+#map1 = affine_map<(d0, d1, d2)[s0, s1, s2] -> ()>
+#map2 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2)>
+#map4 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (s0, s1, s2)>
+module {
+  func.func @three_gathers_index_materialization(%arg0: tensor<1x8x?xf32>, %arg1: tensor<1x8xf32>, %arg2: tensor<1x8xf32>, %arg3: tensor<1x8x?xf32>, %arg4: tensor<50x32x25x2xi32>, %arg5: tensor<50x40x40xi8>, %arg6: index, %arg7: index, %arg8: index, %arg9: index) -> tensor<1x8x?xf32> {
+    %cst = arith.constant dense<0> : vector<1x8x8xi8>
+    %0 = ub.poison : i8
+    %cst_0 = arith.constant dense<39> : vector<1x8x8xindex>
+    %cst_1 = arith.constant dense<0> : vector<1x8x8xindex>
+    %1 = ub.poison : i32
+    %cst_2 = arith.constant dense<0.000000e+00> : vector<1x8x8xf32>
+    %2 = ub.poison : index
+    %3 = ub.poison : f32
+    %c2 = arith.constant 2 : index
+    %c8 = arith.constant 8 : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %4 = tensor.empty() : tensor<1x8x8xf32>
+    %5 = tensor.empty() : tensor<1x8x8xindex>
+    %6 = tensor.empty() : tensor<1x8x8xindex>
+    %7 = tensor.empty() : tensor<1x8x8xindex>
+    %dim = tensor.dim %arg0, %c2 : tensor<1x8x?xf32>
+    %8 = vector.transfer_read %arg1[%c0, %c0], %3 {in_bounds = [true, true]} : tensor<1x8xf32>, vector<1x8xf32>
+    %9 = vector.transfer_read %arg2[%c0, %c0], %3 {in_bounds = [true, true]} : tensor<1x8xf32>, vector<1x8xf32>
+    %10 = vector.create_mask %c1, %c8, %dim : vector<1x8x8xi1>
+    %11 = vector.transfer_read %arg0[%c0, %c0, %c0], %3, %10 {in_bounds = [true, true, true]} : tensor<1x8x?xf32>, vector<1x8x8xf32>
+    %12 = arith.divf %8, %9 : vector<1x8xf32>
+    %13 = vector.broadcast %12 : vector<1x8xf32> to vector<8x1x8xf32>
+    %14 = vector.transpose %13, [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32>
+    %15 = arith.cmpf une, %11, %cst_2 : vector<1x8x8xf32>
+    %16 = arith.select %15, %14, %cst_2 : vector<1x8x8xi1>, vector<1x8x8xf32>
+    %17 = arith.addi %arg6, %arg7 : index
+    %18 = vector.step : vector<8xindex>
+    %19 = vector.broadcast %18 : vector<8xindex> to vector<1x8x8xindex>
+    %20 = vector.transpose %19, [0, 2, 1] : vector<1x8x8xindex> to vector<1x8x8xindex>
+    %21 = vector.broadcast %arg8 : index to vector<1x8x8xindex>
+    %22 = arith.addi %21, %20 : vector<1x8x8xindex>
+    %23 = vector.step : vector<8xindex>
+    %24 = vector.broadcast %arg9 : index to vector<8xindex>
+    %25 = arith.addi %24, %23 : vector<8xindex>
+    %26 = vector.transfer_write %16, %4[%c0, %c0, %c0], %10 {in_bounds = [true, true, true]} : vector<1x8x8xf32>, tensor<1x8x8xf32>
+    %27 = vector.broadcast %17 : index to vector<1x8x8xindex>
+    %28 = vector.transfer_write %27, %5[%c0, %c0, %c0], %10 {in_bounds = [true, true, true]} : vector<1x8x8xindex>, tensor<1x8x8xindex>
+    %29 = vector.transfer_write %22, %6[%c0, %c0, %c0], %10 {in_bounds = [true, true, true]} : vector<1x8x8xindex>, tensor<1x8x8xindex>
+    %30 = vector.broadcast %25 : vector<8xindex> to vector<1x8x8xindex>
+    %31 = vector.transfer_write %30, %7[%c0, %c0, %c0], %10 {in_bounds = [true, true, true]} : vector<1x8x8xindex>, tensor<1x8x8xindex>
+    %32 = iree_vector_ext.transfer_gather %arg4[%c0, %c0, %c0, %c0] [%17, %22, %25 : index, vector<1x8x8xindex>, vector<8xindex>], %1 {indexing_maps = [#map, #map1, #map2, #map3]} : tensor<50x32x25x2xi32>, vector<1x8x8xi32>
+    %33 = arith.index_cast %32 : vector<1x8x8xi32> to vector<1x8x8xindex>
+    %34 = vector.transfer_read %28[%c0, %c0, %c0], %2 {in_bounds = [true, true, true]} : tensor<1x8x8xindex>, vector<1x8x8xindex>
+    %35 = vector.transfer_read %29[%c0, %c0, %c0], %2 {in_bounds = [true, true, true]} : tensor<1x8x8xindex>, vector<1x8x8xindex>
+    %36 = vector.transfer_read %31[%c0, %c0, %c0], %2 {in_bounds = [true, true, true]} : tensor<1x8x8xindex>, vector<1x8x8xindex>
+    %37 = iree_vector_ext.transfer_gather %arg4[%c0, %c0, %c0, %c1] [%34, %35, %36 : vector<1x8x8xindex>, vector<1x8x8xindex>, vector<1x8x8xindex>], %1 {indexing_maps = [#map, #map2, #map2, #map2]} : tensor<50x32x25x2xi32>, vector<1x8x8xi32>
+    %38 = arith.index_cast %37 : vector<1x8x8xi32> to vector<1x8x8xindex>
+    %39 = arith.maxsi %33, %cst_1 : vector<1x8x8xindex>
+    %40 = arith.minui %39, %cst_0 : vector<1x8x8xindex>
+    %41 = arith.maxsi %38, %cst_1 : vector<1x8x8xindex>
+    %42 = arith.minui %41, %cst_0 : vector<1x8x8xindex>
+    %43 = vector.transfer_read %28[%c0, %c0, %c0], %2 {in_bounds = [true, true, true]} : tensor<1x8x8xindex>, vector<1x8x8xindex>
+    %44 = iree_vector_ext.transfer_gather %arg5[%c0, %c0, %c0] [%43, %40, %42 : vector<1x8x8xindex>, vector<1x8x8xindex>, vector<1x8x8xindex>], %0 {indexing_maps = [#map4, #map2, #map2, #map2]} : tensor<50x40x40xi8>, vector<1x8x8xi8>
+    %45 = vector.transfer_read %26[%c0, %c0, %c0], %3 {in_bounds = [true, true, true]} : tensor<1x8x8xf32>, vector<1x8x8xf32>
+    %46 = arith.cmpi ugt, %44, %cst : vector<1x8x8xi8>
+    %47 = arith.select %46, %45, %cst_2 : vector<1x8x8xi1>, vector<1x8x8xf32>
+    %48 = vector.transfer_write %47, %arg3[%c0, %c0, %c0] {in_bounds = [true, true, false]} : vector<1x8x8xf32>, tensor<1x8x?xf32>
+    return %48 : tensor<1x8x?xf32>
+  }
+}
+
+// Verify three transfer_gather ops are produced. Index vectors from the first
+// two gathers (and the clamp ops on their results) feed directly into the
+// third gather as vector SSA values — no tensor.empty<...xindex> or
+// write-read chains.
+//
+// CHECK-LABEL: func.func @three_gathers_index_materialization
+// CHECK-SAME:    %[[IN0:[a-zA-Z0-9]+]]: tensor<1x8x?xf32>
+// CHECK-SAME:    %[[IN1:[a-zA-Z0-9]+]]: tensor<1x8xf32>
+// CHECK-SAME:    %[[IN2:[a-zA-Z0-9]+]]: tensor<1x8xf32>
+// CHECK-SAME:    %[[OUT:[a-zA-Z0-9]+]]: tensor<1x8x?xf32>
+// CHECK-SAME:    %[[TABLE:[a-zA-Z0-9]+]]: tensor<50x32x25x2xi32>
+// CHECK-SAME:    %[[LUT:[a-zA-Z0-9]+]]: tensor<50x40x40xi8>
+//
+//     No index-typed tensors should appear anywhere in the output.
+// CHECK-NOT:     tensor<{{.*}}xindex>
+//
+//     Gather #1 from %indir_table — index vecs consumed directly.
+// CHECK:         %[[G1:.+]] = iree_vector_ext.transfer_gather %[[TABLE]]
+// CHECK-SAME:      : tensor<50x32x25x2xi32>, vector<1x8x8xi32>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+// CHECK:         %[[G1_IDX:.+]] = arith.index_cast %[[G1]] : vector<1x8x8xi32> to vector<1x8x8xindex>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+//
+//     Gather #2 from %indir_table — reuses the same index vecs as #1.
+// CHECK:         %[[G2:.+]] = iree_vector_ext.transfer_gather %[[TABLE]]
+// CHECK-SAME:      : tensor<50x32x25x2xi32>, vector<1x8x8xi32>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+// CHECK:         %[[G2_IDX:.+]] = arith.index_cast %[[G2]] : vector<1x8x8xi32> to vector<1x8x8xindex>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+//
+//     Clamp results of gather #1 and #2 — pure vector ops, no tensor roundtrip.
+//     G1_IDX -> maxsi -> minui = CLAMP_A, G2_IDX -> maxsi -> minui = CLAMP_B.
+// CHECK:         %[[G1_MAX:.+]] = arith.maxsi %[[G1_IDX]], {{.*}} : vector<1x8x8xindex>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+// CHECK:         %[[CLAMP_A:.+]] = arith.minui %[[G1_MAX]], {{.*}} : vector<1x8x8xindex>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+// CHECK:         %[[G2_MAX:.+]] = arith.maxsi %[[G2_IDX]], {{.*}} : vector<1x8x8xindex>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+// CHECK:         %[[CLAMP_B:.+]] = arith.minui %[[G2_MAX]], {{.*}} : vector<1x8x8xindex>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+//
+//     Gather #3 from %lut — takes clamped results directly as index vectors.
+// CHECK:         %[[G3:.+]] = iree_vector_ext.transfer_gather %[[LUT]]
+// CHECK-SAME:      [{{.*}}, %[[CLAMP_A]], %[[CLAMP_B]] : {{.*}}]
+// CHECK-SAME:      : tensor<50x40x40xi8>, vector<1x8x8xi8>
+// CHECK-NOT:     tensor<{{.*}}xindex>
+//
+//     Final select + write.
+// CHECK:         %[[GATE:.+]] = arith.cmpi ugt, %[[G3]]
+// CHECK-NOT:     tensor<{{.*}}xindex>
+// CHECK:         %[[RES:.+]] = arith.select %[[GATE]]
+// CHECK-NOT:     tensor<{{.*}}xindex>
+// CHECK:         vector.transfer_write %[[RES]], %[[OUT]]
+// CHECK-NOT:     tensor<{{.*}}xindex>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_reduction_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_reduction_gfx942.mlir
index 4b04ef0..c84f034 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_reduction_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_vector_distribute_reduction_gfx942.mlir
@@ -289,6 +289,11 @@
   %6 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
   %7 = tensor.empty() : tensor<20x1x64xf32>
   %8 = tensor.empty() : tensor<20x1xf32>
+  %cst_zero = arith.constant 0.000000e+00 : f32
+  %cst_neg_inf = arith.constant 0xFF800000 : f32
+  %acc_fill = linalg.fill ins(%cst_zero : f32) outs(%7: tensor<20x1x64xf32>) -> tensor<20x1x64xf32>
+  %max_fill = linalg.fill ins(%cst_neg_inf : f32) outs(%8: tensor<20x1xf32>) -> tensor<20x1xf32>
+  %sum_fill = linalg.fill ins(%cst_zero : f32) outs(%8: tensor<20x1xf32>) -> tensor<20x1xf32>
   %9:3 = iree_linalg_ext.online_attention  {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
@@ -301,7 +306,7 @@
                 qk_attrs = {lowering_config = #qk_config},
                 pv_attrs = {lowering_config = #pv_config}
                }}
-               ins(%4, %5, %6, %cst : tensor<20x1x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7, %8, %8 : tensor<20x1x64xf32>, tensor<20x1xf32>, tensor<20x1xf32>) {
+               ins(%4, %5, %6, %cst : tensor<20x1x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%acc_fill, %max_fill, %sum_fill : tensor<20x1x64xf32>, tensor<20x1xf32>, tensor<20x1xf32>) {
                 ^bb0(%score: f32):
                   iree_linalg_ext.yield %score : f32
                } -> tensor<20x1x64xf32>, tensor<20x1xf32>, tensor<20x1xf32>
@@ -464,6 +469,13 @@
   %13 = tensor.empty(%6) : tensor<4x1x1x?x32xf16>
   %14 = tensor.empty(%6) : tensor<4x?x1x32x128xf16>
   %15 = tensor.empty() : tensor<4x1x1xf32>
+
+  %cst_acc = arith.constant 0.000000e+00 : f32
+  %cst_max = arith.constant 0xFF800000 : f32
+  %acc_fill = linalg.fill ins(%cst_acc : f32) outs(%12 : tensor<4x1x1x128xf32>) -> tensor<4x1x1x128xf32>
+  %max_fill = linalg.fill ins(%cst_max : f32) outs(%15 : tensor<4x1x1xf32>) -> tensor<4x1x1xf32>
+  %sum_fill = linalg.fill ins(%cst_acc : f32) outs(%15 : tensor<4x1x1xf32>) -> tensor<4x1x1xf32>
+
   %16 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0, 0, 0, 0, 0, 0], sizes = [4096, 1, 1, 1, 32, 128], strides = [1, 1, 1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x1x2x1x32x128xf16>> -> tensor<4096x1x32x128xf16>
   %17 = iree_linalg_ext.gather dimension_map = [0] ins(%16, %9 : tensor<4096x1x32x128xf16>, tensor<4x?xi64>) outs(%14 : tensor<4x?x1x32x128xf16>) -> tensor<4x?x1x32x128xf16>
   %18 = iree_linalg_ext.gather dimension_map = [0] ins(%16, %10 : tensor<4096x1x32x128xf16>, tensor<4x?xi64>) outs(%14 : tensor<4x?x1x32x128xf16>) -> tensor<4x?x1x32x128xf16>
@@ -500,7 +512,7 @@
       lowering_config = #attention_lowering_config
     }
      ins(%11, %17, %18, %cst, %19 : tensor<4x1x1x128xf16>, tensor<4x?x1x32x128xf16>, tensor<4x?x1x32x128xf16>, f16, tensor<4x1x1x?x32xf16>)
-    outs(%12, %15, %15 : tensor<4x1x1x128xf32>, tensor<4x1x1xf32>, tensor<4x1x1xf32>) {
+    outs(%acc_fill, %max_fill, %sum_fill : tensor<4x1x1x128xf32>, tensor<4x1x1xf32>, tensor<4x1x1xf32>) {
       ^bb0(%arg0: f32):
         iree_linalg_ext.yield %arg0 : f32
   } -> tensor<4x1x1x128xf32>, tensor<4x1x1xf32>, tensor<4x1x1xf32>