[DispatchCreation] Skip collapse for scatter-like generics with tensor.extract (#24034)

Skip collapsing `linalg.generic` ops that contain both `tensor.extract`
and `linalg.index` operations. These are strided scatter patterns
(produced by
[ConvertStridedInsertSliceToGeneric](https://github.com/iree-org/iree/pull/23990))
where 1D collapse introduces expensive delinearization (div/mod chains)
on the extract indices, causing a ~3.8x regression. Keeping the
multi-dimensional iteration space allows direct workgroup tiling without
delinearization.

ci-extra: test_torch

Signed-off-by: yzhang93 <zhyuhang88@gmail.com>
diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
index c0106ec..29222d3 100644
--- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
@@ -221,6 +221,27 @@
     return false;
   }
 
+  // Skip collapse for scatter-like generics that use tensor.extract with
+  // linalg.index-based addressing. These are strided scatter patterns where
+  // 1D collapse introduces expensive delinearization (div/mod chains) that
+  // dominates execution. Keeping the multi-dimensional iteration space
+  // allows direct workgroup tiling without delinearization.
+  {
+    bool hasTensorExtract = false;
+    bool hasLinalgIndex = false;
+    genericOp.getBlock()->walk([&](Operation *inner) {
+      if (isa<tensor::ExtractOp>(inner)) {
+        hasTensorExtract = true;
+      }
+      if (isa<linalg::IndexOp>(inner)) {
+        hasLinalgIndex = true;
+      }
+    });
+    if (hasTensorExtract && hasLinalgIndex) {
+      return false;
+    }
+  }
+
   auto hasEncoding = [](Type type) -> bool {
     auto rankedTensorType = dyn_cast<RankedTensorType>(type);
     if (!rankedTensorType || !rankedTensorType.getEncoding()) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
index feb40bd..e98b57e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
@@ -953,3 +953,48 @@
 //       CHECK:   %[[GEN1:.+]] = linalg.generic
 //  CHECK-SAME:     ins(%[[GEN0]]
 //       CHECK:   tensor.parallel_insert_slice %[[GEN1]]
+
+// -----
+
+// Scatter-like generics with tensor.extract + linalg.index should NOT be
+// collapsed. Collapsing introduces expensive delinearization (div/mod) on
+// the index-computed extract indices.
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+util.func public @no_collapse_scatter_generic(%src: tensor<1x25x25x32xf16>) -> tensor<1x52x52x32xf16> {
+  %zero = arith.constant 0.000000e+00 : f16
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c25 = arith.constant 25 : index
+  %empty = tensor.empty() : tensor<1x52x52x32xf16>
+  %result = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } outs(%empty : tensor<1x52x52x32xf16>) {
+  ^bb0(%out: f16):
+    %n = linalg.index 0 : index
+    %h = linalg.index 1 : index
+    %w = linalg.index 2 : index
+    %c_idx = linalg.index 3 : index
+    %sh = arith.subi %h, %c1 : index
+    %rem_h = arith.remsi %sh, %c2 : index
+    %div_h = arith.divsi %sh, %c2 : index
+    %ge = arith.cmpi sge, %sh, %c0 : index
+    %eq = arith.cmpi eq, %rem_h, %c0 : index
+    %lt = arith.cmpi slt, %div_h, %c25 : index
+    %valid = arith.andi %ge, %eq : i1
+    %valid2 = arith.andi %valid, %lt : i1
+    %clamped = arith.maxsi %div_h, %c0 : index
+    %extracted = tensor.extract %src[%n, %clamped, %w, %c_idx] : tensor<1x25x25x32xf16>
+    %val = arith.select %valid2, %extracted, %zero : f16
+    linalg.yield %val : f16
+  } -> tensor<1x52x52x32xf16>
+  util.return %result : tensor<1x52x52x32xf16>
+}
+
+// CHECK-LABEL: @no_collapse_scatter_generic
+// The generic should NOT be collapsed — it should retain 4 parallel dims.
+// CHECK:       linalg.generic
+// CHECK-SAME:    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK:         tensor.extract
+// CHECK:         arith.select