[Preprocessing] Add pass to convert strided insert_slice to linalg.generic (#23990)

Add `iree-preprocessing-convert-strided-insert-slice-to-generic` pass
that converts strided `tensor.insert_slice` into zero-constant
destinations to a `linalg.generic` with index arithmetic. This targets
backward data convolution patterns where the upstream gradient is
scattered into a zero buffer at strided positions.

The replacement generic computes the strided scatter in a single
dispatch, replacing the `Memset + slow_memcpy` pair. Each non-passthrough
dimension recovers its source index with `arith.remsi`/`arith.divsi`; a
bitwise fast path for power-of-2 strides is not part of this change. The
result is selected via `arith.select` for branchless GPU execution.

This PR is a partial implementation for
https://github.com/iree-org/iree/issues/23976.

---------

Signed-off-by: yzhang93 <zhyuhang88@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index 29fb7bb..341a9bf 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -40,6 +40,7 @@
         "Convert1X1FilterConv2DToMatmul.cpp",
         "ConvertConvFilterToChannelsLast.cpp",
         "ConvertConvToChannelsLast.cpp",
+        "ConvertStridedInsertSliceToGeneric.cpp",
         "FoldAttentionWithTranspose.cpp",
         "GeneralizeLinalgMatMul.cpp",
         "Interpreter.cpp",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index d2196ba..73148a7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -31,6 +31,7 @@
     "Convert1X1FilterConv2DToMatmul.cpp"
     "ConvertConvFilterToChannelsLast.cpp"
     "ConvertConvToChannelsLast.cpp"
+    "ConvertStridedInsertSliceToGeneric.cpp"
     "FoldAttentionWithTranspose.cpp"
     "GeneralizeLinalgMatMul.cpp"
     "Interpreter.cpp"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertStridedInsertSliceToGeneric.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertStridedInsertSliceToGeneric.cpp
new file mode 100644
index 0000000..5d0bbee
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertStridedInsertSliceToGeneric.cpp
@@ -0,0 +1,247 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::Preprocessing {
+
+#define GEN_PASS_DEF_CONVERTSTRIDEDINSERTSLICETOGENERICPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc"
+
+namespace {
+
+/// Convert a strided `tensor.insert_slice` into a zero-constant destination
+/// to a `linalg.generic` with index arithmetic.
+///
+/// This targets the backward data convolution pattern where the upstream
+/// gradient is scattered into a zero buffer at strided positions, producing
+/// a separate Memset + slow_memcpy dispatch pair. The replacement generic
+/// computes the strided scatter in a single pass and is potentially
+/// fusable with consumer ops in later dispatch formation passes.
+///
+/// For each output position, the generic checks whether the position maps
+/// to a valid source element (i.e., (pos - offset) is non-negative,
+/// divisible by stride, and the quotient is in-bounds). Source indices are
+/// clamped to valid range so the extract is always safe, and arith.select
+/// chooses between the extracted value and zero.
+///
+/// Apart from the passthrough-size cutoff, the pattern only matches a +0
+/// splat destination, static offsets/sizes/strides with every stride positive
+/// and one non-unit, and a same-rank, statically shaped, non-empty source.
+class ConvertStridedInsertSliceToGeneric
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using Base::Base;
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    // Destination must be a zero splat (not -0.0: the fill value is +0.0).
+    Value dest = op.getDest();
+    SplatElementsAttr splatAttr;
+    if (!matchPattern(dest, m_Constant(&splatAttr))) {
+      return failure();
+    }
+    Attribute splatVal = splatAttr.getSplatValue<Attribute>();
+    bool isZero =
+        TypeSwitch<Attribute, bool>(splatVal)
+            .Case<FloatAttr>([](auto a) { return a.getValue().isPosZero(); })
+            .Case<IntegerAttr>([](auto a) { return a.getValue().isZero(); })
+            .Default([](auto) { return false; });
+    if (!isZero) {
+      return failure();
+    }
+
+    // All offsets, sizes, and strides must be static. Every stride must be
+    // positive (the index math divides by it), with at least one non-unit.
+    SmallVector<int64_t> offsets(op.getStaticOffsets());
+    SmallVector<int64_t> strides(op.getStaticStrides());
+    SmallVector<int64_t> sizes(op.getStaticSizes());
+    if (ShapedType::isDynamicShape(offsets) ||
+        ShapedType::isDynamicShape(strides) ||
+        ShapedType::isDynamicShape(sizes) ||
+        llvm::any_of(strides, [](int64_t s) { return s <= 0; }) ||
+        llvm::all_of(strides, [](int64_t s) { return s == 1; })) {
+      return failure();
+    }
+
+    Value src = op.getSource();
+    auto srcTy = cast<RankedTensorType>(src.getType());
+    auto destTy = cast<RankedTensorType>(dest.getType());
+    // Dims are matched one-to-one below, so rank-reducing inserts are not
+    // supported; an empty source has no valid index to clamp to.
+    if (srcTy.getRank() != destTy.getRank() || !srcTy.hasStaticShape() ||
+        llvm::is_contained(srcTy.getShape(), 0)) {
+      return failure();
+    }
+
+    unsigned origRank = destTy.getRank();
+    auto elemTy = destTy.getElementType();
+    Location loc = op.getLoc();
+
+    // A dim is "passthrough" if it has stride 1, offset 0, and the source
+    // spans the whole dest dim — the scatter generic just copies it without
+    // any index arithmetic (a shorter source would be read out of bounds).
+    auto isPassthrough = [&](unsigned d) {
+      return strides[d] == 1 && offsets[d] == 0 &&
+             srcTy.getDimSize(d) == destTy.getDimSize(d);
+    };
+
+    // Skip conversion when too many non-batch passthrough elements exist
+    // per spatial position. In grouped backward-data convolutions the
+    // group+channel dims are passthrough; when their product is large the
+    // scatter generic is slower than hardware DMA (Memset + slow_memcpy).
+    constexpr int64_t kMaxPassthroughElems = 256;
+    int64_t passthroughElems = 1;
+    for (unsigned i = 1; i < origRank; i++) {
+      if (isPassthrough(i)) {
+        passthroughElems *= destTy.getDimSize(i);
+      }
+    }
+    if (passthroughElems > kMaxPassthroughElems) {
+      return failure();
+    }
+
+    // Collapse contiguous passthrough dims to reduce iteration rank.
+    // E.g., grouped convs have trailing [G, C] dims that collapse to [G*C].
+    // Without this, high-rank scatter generics (5D+) can fail to compile
+    // in downstream passes and produce suboptimal GPU tiling.
+    SmallVector<ReassociationIndices> reassociation;
+    ReassociationIndices currentGroup = {0};
+    for (unsigned i = 1; i < origRank; i++) {
+      if (isPassthrough(i - 1) && isPassthrough(i)) {
+        currentGroup.push_back(i);
+      } else {
+        reassociation.push_back(std::move(currentGroup));
+        currentGroup = {static_cast<int64_t>(i)};
+      }
+    }
+    reassociation.push_back(std::move(currentGroup));
+    bool needsCollapse =
+        llvm::any_of(reassociation, [](const auto &g) { return g.size() > 1; });
+
+    // Compute collapsed metadata (group sizes are products of their dims).
+    unsigned rank = reassociation.size();
+    SmallVector<int64_t> collapsedOffsets(rank), collapsedStrides(rank);
+    SmallVector<int64_t> collapsedDestShape(rank, 1);
+    SmallVector<int64_t> collapsedSrcShape(rank, 1);
+    for (auto [gi, group] : llvm::enumerate(reassociation)) {
+      collapsedOffsets[gi] = offsets[group[0]];
+      collapsedStrides[gi] = strides[group[0]];
+      for (int64_t idx : group) {
+        collapsedDestShape[gi] *= destTy.getDimSize(idx);
+        collapsedSrcShape[gi] *= srcTy.getDimSize(idx);
+      }
+    }
+    auto collapsedSrcTy = RankedTensorType::get(collapsedSrcShape, elemTy);
+    auto collapsedDestTy = RankedTensorType::get(collapsedDestShape, elemTy);
+
+    Value collapsedSrc = src;
+    if (needsCollapse) {
+      collapsedSrc = tensor::CollapseShapeOp::create(
+          rewriter, loc, collapsedSrcTy, src, reassociation);
+    }
+
+    // Build the scatter generic over the (collapsed) dest shape.
+    Value empty =
+        tensor::EmptyOp::create(rewriter, loc, collapsedDestShape, elemTy);
+    AffineMap identityMap = rewriter.getMultiDimIdentityMap(rank);
+    SmallVector<AffineMap> indexingMaps = {identityMap};
+    SmallVector<utils::IteratorType> iterTypes(rank,
+                                               utils::IteratorType::parallel);
+
+    auto genericOp = linalg::GenericOp::create(
+        rewriter, loc, collapsedDestTy, /*inputs=*/ValueRange{},
+        /*outputs=*/ValueRange{empty}, indexingMaps, iterTypes,
+        [&](OpBuilder &b, Location loc, ValueRange /*args*/) {
+          auto cstIdx = [&](int64_t v) {
+            return arith::ConstantOp::create(b, loc, b.getIndexAttr(v));
+          };
+
+          Value zero = cstIdx(0);
+          Value allValid = nullptr;
+          SmallVector<Value> srcIndices;
+          for (unsigned i = 0; i < rank; i++) {
+            Value idx = linalg::IndexOp::create(b, loc, i);
+            // Passthrough dims: source index == dest index.
+            if (collapsedStrides[i] == 1 && collapsedOffsets[i] == 0 &&
+                collapsedSrcShape[i] == collapsedDestShape[i]) {
+              srcIndices.push_back(idx);
+              continue;
+            }
+
+            int64_t stride = collapsedStrides[i];
+            int64_t srcSize = collapsedSrcTy.getDimSize(i);
+            Value shifted =
+                arith::SubIOp::create(b, loc, idx, cstIdx(collapsedOffsets[i]));
+            Value strVal = cstIdx(stride);
+            Value rem = arith::RemSIOp::create(b, loc, shifted, strVal);
+            Value srcIdx = arith::DivSIOp::create(b, loc, shifted, strVal);
+            // Valid when shifted >= 0 && rem == 0 && srcIdx < srcSize. Ops are
+            // built one per statement so their order is compiler-independent.
+            Value nonNeg = arith::CmpIOp::create(
+                b, loc, arith::CmpIPredicate::sge, shifted, zero);
+            Value onGrid = arith::CmpIOp::create(
+                b, loc, arith::CmpIPredicate::eq, rem, zero);
+            Value dimValid = arith::AndIOp::create(b, loc, nonNeg, onGrid);
+            Value inRange = arith::CmpIOp::create(
+                b, loc, arith::CmpIPredicate::slt, srcIdx, cstIdx(srcSize));
+            dimValid = arith::AndIOp::create(b, loc, dimValid, inRange);
+            allValid = allValid
+                           ? arith::AndIOp::create(b, loc, allValid, dimValid)
+                           : dimValid;
+
+            // Clamp to [0, srcSize-1] so the extract is always in-bounds.
+            Value clamped = arith::MaxSIOp::create(b, loc, srcIdx, zero);
+            clamped =
+                arith::MinSIOp::create(b, loc, clamped, cstIdx(srcSize - 1));
+            srcIndices.push_back(clamped);
+          }
+
+          Value zeroElem =
+              arith::ConstantOp::create(b, loc, rewriter.getZeroAttr(elemTy));
+          Value extracted =
+              tensor::ExtractOp::create(b, loc, collapsedSrc, srcIndices);
+          Value result = allValid ? arith::SelectOp::create(b, loc, allValid,
+                                                            extracted, zeroElem)
+                                  : extracted;
+          linalg::YieldOp::create(b, loc, result);
+        });
+
+    Value result = genericOp.getResult(0);
+    if (needsCollapse) {
+      result = tensor::ExpandShapeOp::create(rewriter, loc, destTy, result,
+                                             reassociation);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Runs ConvertStridedInsertSliceToGeneric over the op via a pattern walk.
+struct ConvertStridedInsertSliceToGenericPass
+    : impl::ConvertStridedInsertSliceToGenericPassBase<
+          ConvertStridedInsertSliceToGenericPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, linalg::LinalgDialect,
+                    tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet scatterPatterns(&getContext());
+    scatterPatterns.insert<ConvertStridedInsertSliceToGeneric>(&getContext());
+    walkAndApplyPatterns(getOperation(), std::move(scatterPatterns));
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index 6303e6e..f9e1290 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -78,6 +78,11 @@
   ];
 }
 
+def ConvertStridedInsertSliceToGenericPass:
+    Pass<"iree-preprocessing-convert-strided-insert-slice-to-generic", ""> {
+  let summary = "Converts strided insert_slice into zero-constant destinations to linalg.generic with index arithmetic.";
+} // Dependent dialects are registered in C++ via getDependentDialects().
+
 def FoldAttentionWithTransposePass :
     Pass<"iree-preprocessing-fold-attention-with-transpose", ""> {
   let summary = "Fold attention operation with transpose";
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 38583f1..9ae24d7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -21,6 +21,7 @@
             "conv1x1_to_matmul.mlir",
             "conv_filter_to_channels_last.mlir",
             "conv_to_channels_last.mlir",
+            "convert_strided_insert_slice_to_generic.mlir",
             "fold_attention_with_transpose.mlir",
             "generalize_linalg_matmul.mlir",
             "make_single_dispatch_for_function.mlir",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index e947c57..75d31de 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -18,6 +18,7 @@
     "conv1x1_to_matmul.mlir"
     "conv_filter_to_channels_last.mlir"
     "conv_to_channels_last.mlir"
+    "convert_strided_insert_slice_to_generic.mlir"
     "fold_attention_with_transpose.mlir"
     "generalize_linalg_matmul.mlir"
     "make_single_dispatch_for_function.mlir"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/convert_strided_insert_slice_to_generic.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/convert_strided_insert_slice_to_generic.mlir
new file mode 100644
index 0000000..9c0088f
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/convert_strided_insert_slice_to_generic.mlir
@@ -0,0 +1,134 @@
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --iree-preprocessing-convert-strided-insert-slice-to-generic %s | FileCheck %s
+
+// Converted: stride-2 with non-zero offsets, no passthrough dims, no collapse.
+// Checks the index arithmetic: sub offset, rem/div for stride check, bounds check, clamp, extract, select.
+util.func public @stride2_no_passthrough(%src: tensor<4x4xf16>) -> tensor<9x9xf16> {
+  %cst = arith.constant dense<0.000000e+00> : tensor<9x9xf16>
+  %0 = tensor.insert_slice %src into %cst[1, 1] [4, 4] [2, 2] : tensor<4x4xf16> into tensor<9x9xf16>
+  util.return %0 : tensor<9x9xf16>
+}
+
+// CHECK-LABEL: @stride2_no_passthrough
+// CHECK-SAME:      %[[SRC:.*]]: tensor<4x4xf16>
+// CHECK-NOT:   tensor.insert_slice
+// CHECK:       %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME:      iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:      outs({{.*}} : tensor<9x9xf16>)
+// CHECK:         linalg.index 0
+// CHECK:         arith.subi
+// CHECK:         arith.remsi
+// CHECK:         arith.divsi
+// CHECK:         linalg.index 1
+// CHECK:         arith.subi
+// CHECK:         arith.remsi
+// CHECK:         arith.divsi
+// CHECK:         tensor.extract %[[SRC]]
+// CHECK:         arith.select
+// CHECK:         linalg.yield
+// CHECK:       util.return %[[GENERIC]]
+
+// -----
+
+// Converted: stride-3 with non-zero offsets.
+util.func public @stride3_no_passthrough(%src: tensor<3x3xf32>) -> tensor<10x10xf32> {
+  %cst = arith.constant dense<0.000000e+00> : tensor<10x10xf32>
+  %0 = tensor.insert_slice %src into %cst[1, 1] [3, 3] [3, 3] : tensor<3x3xf32> into tensor<10x10xf32>
+  util.return %0 : tensor<10x10xf32>
+}
+
+// CHECK-LABEL: @stride3_no_passthrough
+// CHECK-SAME:      %[[SRC:.*]]: tensor<3x3xf32>
+// CHECK-NOT:   tensor.insert_slice
+// CHECK:       %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME:      outs({{.*}} : tensor<10x10xf32>)
+// CHECK:         linalg.index 0
+// CHECK:         arith.subi
+// CHECK:         arith.remsi
+// CHECK:         arith.divsi
+// CHECK:         linalg.index 1
+// CHECK:         arith.subi
+// CHECK:         arith.remsi
+// CHECK:         arith.divsi
+// CHECK:         tensor.extract %[[SRC]]
+// CHECK:         arith.select
+// CHECK:       util.return %[[GENERIC]]
+
+// -----
+
+// Converted with dim collapse: 5D input with passthrough trailing dims [3,4]
+// collapsed to 4D. Checks collapse_shape, 4D generic, and expand_shape.
+util.func public @stride2_with_collapse(%src: tensor<1x25x25x4x8xf16>) -> tensor<1x52x52x4x8xf16> {
+  %cst = arith.constant dense<0.000000e+00> : tensor<1x52x52x4x8xf16>
+  %0 = tensor.insert_slice %src into %cst[0, 1, 1, 0, 0] [1, 25, 25, 4, 8] [1, 2, 2, 1, 1] : tensor<1x25x25x4x8xf16> into tensor<1x52x52x4x8xf16>
+  util.return %0 : tensor<1x52x52x4x8xf16>
+}
+
+// CHECK-LABEL: @stride2_with_collapse
+// CHECK-SAME:      %[[SRC:.*]]: tensor<1x25x25x4x8xf16>
+// CHECK-NOT:   tensor.insert_slice
+// Source collapsed: dims [3,4] merged (4*8=32).
+// CHECK:       %[[CSRC:.*]] = tensor.collapse_shape %[[SRC]]
+// CHECK-SAME:      tensor<1x25x25x4x8xf16> into tensor<1x25x25x32xf16>
+// Generic at reduced rank 4.
+// CHECK:       %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:      outs({{.*}} : tensor<1x52x52x32xf16>)
+// Dim 0 (batch): passthrough.
+// CHECK:         linalg.index 0
+// Dim 1 (H): strided.
+// CHECK:         linalg.index 1
+// CHECK:         arith.subi
+// CHECK:         arith.remsi
+// CHECK:         arith.divsi
+// Dim 2 (W): strided.
+// CHECK:         linalg.index 2
+// CHECK:         arith.subi
+// CHECK:         arith.remsi
+// CHECK:         arith.divsi
+// Dim 3 (collapsed G*C): passthrough.
+// CHECK:         linalg.index 3
+// CHECK:         tensor.extract %[[CSRC]]
+// CHECK:         arith.select
+// CHECK:         linalg.yield
+// Result expanded back to 5D.
+// CHECK:       %[[EXP:.*]] = tensor.expand_shape %[[GENERIC]]
+// CHECK-SAME:      tensor<1x52x52x32xf16> into tensor<1x52x52x4x8xf16>
+// CHECK:       util.return %[[EXP]]
+
+// -----
+
+// No transformation: all strides are 1.
+util.func public @no_transform_unit_strides(%src: tensor<32x25x25x2048xf16>) -> tensor<32x50x50x2048xf16> {
+  %cst = arith.constant dense<0.000000e+00> : tensor<32x50x50x2048xf16>
+  %0 = tensor.insert_slice %src into %cst[0, 1, 1, 0] [32, 25, 25, 2048] [1, 1, 1, 1] : tensor<32x25x25x2048xf16> into tensor<32x50x50x2048xf16>
+  util.return %0 : tensor<32x50x50x2048xf16>
+}
+
+// CHECK-LABEL: @no_transform_unit_strides
+// CHECK:       tensor.insert_slice
+// CHECK-NOT:   linalg.generic
+
+// -----
+
+// No transformation: destination is not a zero constant.
+util.func public @no_transform_nonzero_dest(%src: tensor<4x4xf32>, %dest: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %0 = tensor.insert_slice %src into %dest[0, 0] [4, 4] [2, 2] : tensor<4x4xf32> into tensor<8x8xf32>
+  util.return %0 : tensor<8x8xf32>
+}
+
+// CHECK-LABEL: @no_transform_nonzero_dest
+// CHECK:       tensor.insert_slice
+// CHECK-NOT:   linalg.generic
+
+// -----
+
+// No transformation: non-batch passthrough product (2048; dim 0 is skipped) exceeds the threshold of 256.
+util.func public @no_transform_large_passthrough(%src: tensor<32x25x25x2048xf16>) -> tensor<32x50x50x2048xf16> {
+  %cst = arith.constant dense<0.000000e+00> : tensor<32x50x50x2048xf16>
+  %0 = tensor.insert_slice %src into %cst[0, 0, 0, 0] [32, 25, 25, 2048] [1, 2, 2, 1] : tensor<32x25x25x2048xf16> into tensor<32x50x50x2048xf16>
+  util.return %0 : tensor<32x50x50x2048xf16>
+}
+
+// CHECK-LABEL: @no_transform_large_passthrough
+// CHECK:       tensor.insert_slice
+// CHECK-NOT:   linalg.generic