[Preprocessing] Add pass to convert strided insert_slice to linalg.generic (#23990)
Add `iree-preprocessing-convert-strided-insert-slice-to-generic` pass
that converts strided `tensor.insert_slice` ops whose destination is a
zero constant into a `linalg.generic` with index arithmetic. This targets
backward data convolution patterns where the upstream gradient is
scattered into a zero buffer at strided positions.
The replacement generic computes the strided scatter in a single
dispatch, replacing the `Memset + slow_memcpy` pair. Each strided dim
recovers its source index with `arith.remsi`/`arith.divsi`, and the
result is selected via `arith.select` for branchless GPU execution.
This PR is a partial implementation for
https://github.com/iree-org/iree/issues/23976.
---------
Signed-off-by: yzhang93 <zhyuhang88@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index 29fb7bb..341a9bf 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -40,6 +40,7 @@
"Convert1X1FilterConv2DToMatmul.cpp",
"ConvertConvFilterToChannelsLast.cpp",
"ConvertConvToChannelsLast.cpp",
+ "ConvertStridedInsertSliceToGeneric.cpp",
"FoldAttentionWithTranspose.cpp",
"GeneralizeLinalgMatMul.cpp",
"Interpreter.cpp",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index d2196ba..73148a7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -31,6 +31,7 @@
"Convert1X1FilterConv2DToMatmul.cpp"
"ConvertConvFilterToChannelsLast.cpp"
"ConvertConvToChannelsLast.cpp"
+ "ConvertStridedInsertSliceToGeneric.cpp"
"FoldAttentionWithTranspose.cpp"
"GeneralizeLinalgMatMul.cpp"
"Interpreter.cpp"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/ConvertStridedInsertSliceToGeneric.cpp b/compiler/src/iree/compiler/Preprocessing/Common/ConvertStridedInsertSliceToGeneric.cpp
new file mode 100644
index 0000000..5d0bbee
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/ConvertStridedInsertSliceToGeneric.cpp
@@ -0,0 +1,247 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Preprocessing/Common/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::Preprocessing {
+
+#define GEN_PASS_DEF_CONVERTSTRIDEDINSERTSLICETOGENERICPASS
+#include "iree/compiler/Preprocessing/Common/Passes.h.inc"
+
+namespace {
+
+/// Convert a static, non-rank-reduced strided `tensor.insert_slice` into a
+/// +0 splat constant destination to a `linalg.generic` with index arithmetic.
+///
+/// This targets the backward data convolution pattern where the upstream
+/// gradient is scattered into a zero buffer at strided positions, producing
+/// a separate Memset + slow_memcpy dispatch pair. The replacement generic
+/// computes the strided scatter in a single pass and is potentially
+/// fusable with consumer ops in later dispatch formation passes.
+///
+/// For each output position, the generic checks whether the position maps
+/// to a valid source element (i.e., (pos - offset) is non-negative,
+/// divisible by stride, and the quotient is in-bounds). Source indices are
+/// clamped to valid range so the extract is always safe, and arith.select
+/// chooses between the extracted value and zero.
+class ConvertStridedInsertSliceToGeneric
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using Base::Base;
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    // Destination must be a +0 splat constant. Negative zero is rejected
+    // because the generic below always yields +0 for untouched positions.
+    Value dest = op.getDest();
+    SplatElementsAttr splatAttr;
+    if (!matchPattern(dest, m_Constant(&splatAttr))) {
+      return failure();
+    }
+    Attribute splatVal = splatAttr.getSplatValue<Attribute>();
+    bool isZero =
+        TypeSwitch<Attribute, bool>(splatVal)
+            .Case<FloatAttr>([](auto a) { return a.getValue().isPosZero(); })
+            .Case<IntegerAttr>([](auto a) { return a.getValue().isZero(); })
+            .Default([](auto) { return false; });
+    if (!isZero) {
+      return failure();
+    }
+
+    // All offsets, sizes, and strides must be static. Strides must also be
+    // positive (the rem/div below rely on it) with at least one non-unit.
+    SmallVector<int64_t> offsets(op.getStaticOffsets());
+    SmallVector<int64_t> strides(op.getStaticStrides());
+    SmallVector<int64_t> sizes(op.getStaticSizes());
+    if (ShapedType::isDynamicShape(offsets) ||
+        ShapedType::isDynamicShape(strides) ||
+        ShapedType::isDynamicShape(sizes)) {
+      return failure();
+    }
+    if (llvm::any_of(strides, [](int64_t s) { return s <= 0; }) ||
+        llvm::all_of(strides, [](int64_t s) { return s == 1; })) {
+      return failure();
+    }
+
+    Value src = op.getSource();
+    auto srcTy = cast<RankedTensorType>(src.getType());
+    auto destTy = cast<RankedTensorType>(dest.getType());
+
+    // A rank-reduced source has fewer dims than `strides`, and an empty
+    // source has no element to clamp the extract index to; skip both.
+    if (srcTy.getRank() != destTy.getRank() || !srcTy.hasStaticShape() ||
+        srcTy.getNumElements() == 0) {
+      return failure();
+    }
+
+    unsigned origRank = destTy.getRank();
+    auto elemTy = destTy.getElementType();
+    Location loc = op.getLoc();
+
+    // A dim is "passthrough" if it has stride 1, offset 0, and the source
+    // spans the whole destination dim. Only then is every dest index also
+    // a valid source index that can be used without index arithmetic.
+    auto isPassthrough = [&](unsigned d) {
+      return strides[d] == 1 && offsets[d] == 0 &&
+             srcTy.getDimSize(d) == destTy.getDimSize(d);
+    };
+
+    // Skip conversion when too many non-batch passthrough elements exist
+    // per spatial position. In grouped backward-data convolutions the
+    // group+channel dims are passthrough; when their product is large the
+    // scatter generic is slower than hardware DMA (Memset + slow_memcpy).
+    constexpr int64_t kMaxPassthroughElems = 256;
+    int64_t passthroughElems = 1;
+    for (unsigned i = 1; i < origRank; i++) {
+      if (isPassthrough(i)) {
+        passthroughElems *= destTy.getDimSize(i);
+      }
+    }
+    if (passthroughElems > kMaxPassthroughElems) {
+      return failure();
+    }
+
+    // Collapse contiguous passthrough dims to reduce iteration rank.
+    // E.g., grouped convs have trailing [G, C] dims that collapse to [G*C].
+    // Without this, high-rank scatter generics (5D+) can fail to compile
+    // in downstream passes and produce suboptimal GPU tiling.
+    SmallVector<ReassociationIndices> reassociation;
+    reassociation.push_back(ReassociationIndices{0});
+    for (unsigned i = 1; i < origRank; i++) {
+      if (isPassthrough(i - 1) && isPassthrough(i)) {
+        reassociation.back().push_back(i);
+      } else {
+        reassociation.push_back(ReassociationIndices{static_cast<int64_t>(i)});
+      }
+    }
+    bool needsCollapse =
+        llvm::any_of(reassociation, [](const auto &g) { return g.size() > 1; });
+
+    // Compute collapsed metadata.
+    unsigned rank = reassociation.size();
+    SmallVector<int64_t> collapsedOffsets(rank), collapsedStrides(rank);
+    SmallVector<int64_t> collapsedDestShape(rank), collapsedSrcShape(rank);
+    for (auto [gi, group] : llvm::enumerate(reassociation)) {
+      collapsedOffsets[gi] = offsets[group[0]];
+      collapsedStrides[gi] = strides[group[0]];
+      int64_t destDim = 1, srcDim = 1;
+      for (int64_t idx : group) {
+        destDim *= destTy.getDimSize(idx);
+        srcDim *= srcTy.getDimSize(idx);
+      }
+      collapsedDestShape[gi] = destDim;
+      collapsedSrcShape[gi] = srcDim;
+    }
+    auto collapsedSrcTy = RankedTensorType::get(collapsedSrcShape, elemTy);
+    auto collapsedDestTy = RankedTensorType::get(collapsedDestShape, elemTy);
+
+    Value collapsedSrc = src;
+    if (needsCollapse) {
+      collapsedSrc = tensor::CollapseShapeOp::create(
+          rewriter, loc, collapsedSrcTy, src, reassociation);
+    }
+
+    // Build the scatter generic over the (collapsed) dest shape.
+    Value empty =
+        tensor::EmptyOp::create(rewriter, loc, collapsedDestShape, elemTy);
+    AffineMap identityMap = rewriter.getMultiDimIdentityMap(rank);
+    SmallVector<AffineMap> indexingMaps = {identityMap};
+    SmallVector<utils::IteratorType> iterTypes(rank,
+                                               utils::IteratorType::parallel);
+
+    auto genericOp = linalg::GenericOp::create(
+        rewriter, loc, collapsedDestTy, /*inputs=*/ValueRange{},
+        /*outputs=*/ValueRange{empty}, indexingMaps, iterTypes,
+        [&](OpBuilder &b, Location loc, ValueRange /*args*/) {
+          auto cstIdx = [&](int64_t v) {
+            return arith::ConstantOp::create(b, loc, b.getIndexAttr(v));
+          };
+          Value zero = cstIdx(0);
+          Value allValid = nullptr;
+          SmallVector<Value> srcIndices;
+          for (unsigned i = 0; i < rank; i++) {
+            Value idx = linalg::IndexOp::create(b, loc, i);
+            // Passthrough dims: source index == dest index.
+            if (collapsedStrides[i] == 1 && collapsedOffsets[i] == 0 &&
+                collapsedSrcShape[i] == collapsedDestShape[i]) {
+              srcIndices.push_back(idx);
+              continue;
+            }
+
+            int64_t srcSize = collapsedSrcTy.getDimSize(i);
+            Value shifted =
+                arith::SubIOp::create(b, loc, idx, cstIdx(collapsedOffsets[i]));
+            Value strVal = cstIdx(collapsedStrides[i]);
+            Value rem = arith::RemSIOp::create(b, loc, shifted, strVal);
+            Value srcIdx = arith::DivSIOp::create(b, loc, shifted, strVal);
+
+            // Valid when shifted >= 0 && rem == 0 && srcIdx < srcSize. One op
+            // per statement keeps the emitted order compiler-independent.
+            Value nonNeg = arith::CmpIOp::create(
+                b, loc, arith::CmpIPredicate::sge, shifted, zero);
+            Value onStride = arith::CmpIOp::create(
+                b, loc, arith::CmpIPredicate::eq, rem, zero);
+            Value dimValid = arith::AndIOp::create(b, loc, nonNeg, onStride);
+            Value srcSizeVal = cstIdx(srcSize);
+            Value inBounds = arith::CmpIOp::create(
+                b, loc, arith::CmpIPredicate::slt, srcIdx, srcSizeVal);
+            dimValid = arith::AndIOp::create(b, loc, dimValid, inBounds);
+            allValid = allValid
+                           ? arith::AndIOp::create(b, loc, allValid, dimValid)
+                           : dimValid;
+
+            // Clamp to [0, srcSize-1] so the extract is always in-bounds.
+            Value clamped = arith::MaxSIOp::create(b, loc, srcIdx, zero);
+            clamped =
+                arith::MinSIOp::create(b, loc, clamped, cstIdx(srcSize - 1));
+            srcIndices.push_back(clamped);
+          }
+
+          Value zeroElem =
+              arith::ConstantOp::create(b, loc, rewriter.getZeroAttr(elemTy));
+          Value extracted =
+              tensor::ExtractOp::create(b, loc, collapsedSrc, srcIndices);
+          Value result = allValid ? arith::SelectOp::create(b, loc, allValid,
+                                                            extracted, zeroElem)
+                                  : extracted;
+          linalg::YieldOp::create(b, loc, result);
+        });
+
+    Value result = genericOp.getResult(0);
+    if (needsCollapse) {
+      result = tensor::ExpandShapeOp::create(rewriter, loc, destTy, result,
+                                             reassociation);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Runs ConvertStridedInsertSliceToGeneric via the walk-based pattern driver.
+struct ConvertStridedInsertSliceToGenericPass
+    : impl::ConvertStridedInsertSliceToGenericPassBase<
+          ConvertStridedInsertSliceToGenericPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, linalg::LinalgDialect,
+                    tensor::TensorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<ConvertStridedInsertSliceToGeneric>(patterns.getContext());
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler::Preprocessing
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index 6303e6e..f9e1290 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -78,6 +78,11 @@
];
}
+def ConvertStridedInsertSliceToGenericPass:
+ Pass<"iree-preprocessing-convert-strided-insert-slice-to-generic", ""> {
+ let summary = "Converts strided insert_slice into zero-constant destinations to linalg.generic with index arithmetic.";
+}
+
def FoldAttentionWithTransposePass :
Pass<"iree-preprocessing-fold-attention-with-transpose", ""> {
let summary = "Fold attention operation with transpose";
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
index 38583f1..9ae24d7 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel
@@ -21,6 +21,7 @@
"conv1x1_to_matmul.mlir",
"conv_filter_to_channels_last.mlir",
"conv_to_channels_last.mlir",
+ "convert_strided_insert_slice_to_generic.mlir",
"fold_attention_with_transpose.mlir",
"generalize_linalg_matmul.mlir",
"make_single_dispatch_for_function.mlir",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
index e947c57..75d31de 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt
@@ -18,6 +18,7 @@
"conv1x1_to_matmul.mlir"
"conv_filter_to_channels_last.mlir"
"conv_to_channels_last.mlir"
+ "convert_strided_insert_slice_to_generic.mlir"
"fold_attention_with_transpose.mlir"
"generalize_linalg_matmul.mlir"
"make_single_dispatch_for_function.mlir"
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/convert_strided_insert_slice_to_generic.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/convert_strided_insert_slice_to_generic.mlir
new file mode 100644
index 0000000..9c0088f
--- /dev/null
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/convert_strided_insert_slice_to_generic.mlir
@@ -0,0 +1,134 @@
+// RUN: iree-opt --split-input-file --mlir-print-local-scope --iree-preprocessing-convert-strided-insert-slice-to-generic %s | FileCheck %s
+
+// Converted: stride 2, offset 1 in both dims; no passthrough dims, no collapse.
+// Checks per-dim sub/rem/div index arithmetic, then extract, select, yield.
+util.func public @stride2_no_passthrough(%src: tensor<4x4xf16>) -> tensor<9x9xf16> {
+ %cst = arith.constant dense<0.000000e+00> : tensor<9x9xf16>
+ %0 = tensor.insert_slice %src into %cst[1, 1] [4, 4] [2, 2] : tensor<4x4xf16> into tensor<9x9xf16>
+ util.return %0 : tensor<9x9xf16>
+}
+
+// CHECK-LABEL: @stride2_no_passthrough
+// CHECK-SAME: %[[SRC:.*]]: tensor<4x4xf16>
+// CHECK-NOT: tensor.insert_slice
+// CHECK: %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: outs({{.*}} : tensor<9x9xf16>)
+// CHECK: linalg.index 0
+// CHECK: arith.subi
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// CHECK: linalg.index 1
+// CHECK: arith.subi
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// CHECK: tensor.extract %[[SRC]]
+// CHECK: arith.select
+// CHECK: linalg.yield
+// CHECK: util.return %[[GENERIC]]
+
+// -----
+
+// Converted: stride 3 (not a power of two) with offset 1 in both dims.
+util.func public @stride3_no_passthrough(%src: tensor<3x3xf32>) -> tensor<10x10xf32> {
+ %cst = arith.constant dense<0.000000e+00> : tensor<10x10xf32>
+ %0 = tensor.insert_slice %src into %cst[1, 1] [3, 3] [3, 3] : tensor<3x3xf32> into tensor<10x10xf32>
+ util.return %0 : tensor<10x10xf32>
+}
+
+// CHECK-LABEL: @stride3_no_passthrough
+// CHECK-SAME: %[[SRC:.*]]: tensor<3x3xf32>
+// CHECK-NOT: tensor.insert_slice
+// CHECK: %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME: outs({{.*}} : tensor<10x10xf32>)
+// CHECK: linalg.index 0
+// CHECK: arith.subi
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// CHECK: linalg.index 1
+// CHECK: arith.subi
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// CHECK: tensor.extract %[[SRC]]
+// CHECK: arith.select
+// CHECK: util.return %[[GENERIC]]
+
+// -----
+
+// Converted with dim collapse: trailing passthrough dims [3,4] of the 5D input
+// merge, giving a 4D generic. Checks collapse_shape, the generic, expand_shape.
+util.func public @stride2_with_collapse(%src: tensor<1x25x25x4x8xf16>) -> tensor<1x52x52x4x8xf16> {
+ %cst = arith.constant dense<0.000000e+00> : tensor<1x52x52x4x8xf16>
+ %0 = tensor.insert_slice %src into %cst[0, 1, 1, 0, 0] [1, 25, 25, 4, 8] [1, 2, 2, 1, 1] : tensor<1x25x25x4x8xf16> into tensor<1x52x52x4x8xf16>
+ util.return %0 : tensor<1x52x52x4x8xf16>
+}
+
+// CHECK-LABEL: @stride2_with_collapse
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x25x25x4x8xf16>
+// CHECK-NOT: tensor.insert_slice
+// Source collapsed: dims [3,4] merged (4*8=32).
+// CHECK: %[[CSRC:.*]] = tensor.collapse_shape %[[SRC]]
+// CHECK-SAME: tensor<1x25x25x4x8xf16> into tensor<1x25x25x32xf16>
+// Generic at reduced rank 4.
+// CHECK: %[[GENERIC:.*]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: outs({{.*}} : tensor<1x52x52x32xf16>)
+// Dim 0 (batch): passthrough.
+// CHECK: linalg.index 0
+// Dim 1 (H): strided.
+// CHECK: linalg.index 1
+// CHECK: arith.subi
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// Dim 2 (W): strided.
+// CHECK: linalg.index 2
+// CHECK: arith.subi
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// Dim 3 (collapsed G*C): passthrough.
+// CHECK: linalg.index 3
+// CHECK: tensor.extract %[[CSRC]]
+// CHECK: arith.select
+// CHECK: linalg.yield
+// Result expanded back to 5D.
+// CHECK: %[[EXP:.*]] = tensor.expand_shape %[[GENERIC]]
+// CHECK-SAME: tensor<1x52x52x32xf16> into tensor<1x52x52x4x8xf16>
+// CHECK: util.return %[[EXP]]
+
+// -----
+
+// No transformation: all strides are 1, so the insert is offset-only, not strided.
+util.func public @no_transform_unit_strides(%src: tensor<32x25x25x2048xf16>) -> tensor<32x50x50x2048xf16> {
+ %cst = arith.constant dense<0.000000e+00> : tensor<32x50x50x2048xf16>
+ %0 = tensor.insert_slice %src into %cst[0, 1, 1, 0] [32, 25, 25, 2048] [1, 1, 1, 1] : tensor<32x25x25x2048xf16> into tensor<32x50x50x2048xf16>
+ util.return %0 : tensor<32x50x50x2048xf16>
+}
+
+// CHECK-LABEL: @no_transform_unit_strides
+// CHECK: tensor.insert_slice
+// CHECK-NOT: linalg.generic
+
+// -----
+
+// No transformation: destination is a function argument, not a zero constant.
+util.func public @no_transform_nonzero_dest(%src: tensor<4x4xf32>, %dest: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %0 = tensor.insert_slice %src into %dest[0, 0] [4, 4] [2, 2] : tensor<4x4xf32> into tensor<8x8xf32>
+ util.return %0 : tensor<8x8xf32>
+}
+
+// CHECK-LABEL: @no_transform_nonzero_dest
+// CHECK: tensor.insert_slice
+// CHECK-NOT: linalg.generic
+
+// -----
+
+// No transformation: non-batch passthrough product (2048; dim 0 is skipped) exceeds the threshold of 256.
+util.func public @no_transform_large_passthrough(%src: tensor<32x25x25x2048xf16>) -> tensor<32x50x50x2048xf16> {
+ %cst = arith.constant dense<0.000000e+00> : tensor<32x50x50x2048xf16>
+ %0 = tensor.insert_slice %src into %cst[0, 0, 0, 0] [32, 25, 25, 2048] [1, 2, 2, 1] : tensor<32x25x25x2048xf16> into tensor<32x50x50x2048xf16>
+ util.return %0 : tensor<32x50x50x2048xf16>
+}
+
+// CHECK-LABEL: @no_transform_large_passthrough
+// CHECK: tensor.insert_slice
+// CHECK-NOT: linalg.generic