[DispatchCreation] Skip collapse for scatter-like generics with tensor.extract (#24034)
Skip collapsing `linalg.generic` ops that contain both tensor.extract
and linalg.index operations. These are strided scatter patterns
(produced by
[ConvertStridedInsertSliceToGeneric](https://github.com/iree-org/iree/pull/23990))
where 1D collapse introduces expensive delinearization (div/mod chains)
on the extract indices, causing a ~3.8x performance regression. Keeping the
multi-dimensional iteration space allows direct workgroup tiling without
delinearization.
ci-extra: test_torch
Signed-off-by: yzhang93 <zhyuhang88@gmail.com>
diff --git a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
index c0106ec..29222d3 100644
--- a/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
@@ -221,6 +221,27 @@
return false;
}
+ // Skip collapse for scatter-like generics that use tensor.extract with
+ // linalg.index-based addressing. These are strided scatter patterns where
+ // 1D collapse introduces expensive delinearization (div/mod chains) that
+  // dominates execution time. Keeping the multi-dimensional iteration space
+ // allows direct workgroup tiling without delinearization.
+ {
+ bool hasTensorExtract = false;
+ bool hasLinalgIndex = false;
+ genericOp.getBlock()->walk([&](Operation *inner) {
+ if (isa<tensor::ExtractOp>(inner)) {
+ hasTensorExtract = true;
+ }
+ if (isa<linalg::IndexOp>(inner)) {
+ hasLinalgIndex = true;
+ }
+ });
+ if (hasTensorExtract && hasLinalgIndex) {
+ return false;
+ }
+ }
+
auto hasEncoding = [](Type type) -> bool {
auto rankedTensorType = dyn_cast<RankedTensorType>(type);
if (!rankedTensorType || !rankedTensorType.getEncoding()) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
index feb40bd..e98b57e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/collapse_dimensions.mlir
@@ -953,3 +953,48 @@
// CHECK: %[[GEN1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[GEN0]]
// CHECK: tensor.parallel_insert_slice %[[GEN1]]
+
+// -----
+
+// Scatter-like generics with tensor.extract + linalg.index should NOT be
+// collapsed. Collapsing introduces expensive delinearization (div/mod) on
+// the index-computed extract indices.
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+util.func public @no_collapse_scatter_generic(%src: tensor<1x25x25x32xf16>) -> tensor<1x52x52x32xf16> {
+ %zero = arith.constant 0.000000e+00 : f16
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c25 = arith.constant 25 : index
+ %empty = tensor.empty() : tensor<1x52x52x32xf16>
+ %result = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } outs(%empty : tensor<1x52x52x32xf16>) {
+ ^bb0(%out: f16):
+ %n = linalg.index 0 : index
+ %h = linalg.index 1 : index
+ %w = linalg.index 2 : index
+ %c_idx = linalg.index 3 : index
+ %sh = arith.subi %h, %c1 : index
+ %rem_h = arith.remsi %sh, %c2 : index
+ %div_h = arith.divsi %sh, %c2 : index
+ %ge = arith.cmpi sge, %sh, %c0 : index
+ %eq = arith.cmpi eq, %rem_h, %c0 : index
+ %lt = arith.cmpi slt, %div_h, %c25 : index
+ %valid = arith.andi %ge, %eq : i1
+ %valid2 = arith.andi %valid, %lt : i1
+ %clamped = arith.maxsi %div_h, %c0 : index
+ %extracted = tensor.extract %src[%n, %clamped, %w, %c_idx] : tensor<1x25x25x32xf16>
+ %val = arith.select %valid2, %extracted, %zero : f16
+ linalg.yield %val : f16
+ } -> tensor<1x52x52x32xf16>
+ util.return %result : tensor<1x52x52x32xf16>
+}
+
+// CHECK-LABEL: @no_collapse_scatter_generic
+// The generic should NOT be collapsed — it should retain 4 parallel dims.
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK: tensor.extract
+// CHECK: arith.select