Forward tensor.insert_slice coming from in_parallel lowering to flow.… (#8757)

* Forward tensor.insert_slice coming from in_parallel lowering to flow.dispatch.tensor_store

This pattern is currently necessary for correctness: it accounts for the fact that
InParallel is distributed across multiple workgroups when lowering to HAL,
while the current implementation connects its result to a sequential
tensor.insert_slice and only later to flow.dispatch.tensor_store.
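
Concretely, the forwarding rewrites IR of roughly the shape below into a direct store. This is a
simplified sketch: the value names are made up, the 250x1020 shapes loosely mirror the updated
linalg_transform.mlir test, and dynamic-dimension operands are elided.

```mlir
// Before: the per-workgroup tile is first inserted into the full tensor...
%full = tensor.insert_slice %tile into %init[%off, 0] [%sz, 1020] [1, 1]
    : tensor<?x1020xf32> into tensor<250x1020xf32>
// ...and only then stored at the flow level.
flow.dispatch.tensor.store %full, %out, offsets = [0, 0], sizes = [250, 1020], strides = [1, 1]
    : tensor<250x1020xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>

// After: the tile is stored directly, with the insert_slice
// offsets/sizes/strides folded into the store.
flow.dispatch.tensor.store %tile, %out, offsets = [%off, 0], sizes = [%sz, 1020], strides = [1, 1]
    : tensor<?x1020xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>
```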

In the future, all the rewrites in this file should be done as part of the
InParallel -> HAL rewrite, but dialect dependencies and layering currently
impose a phase ordering that prevents this.

A similar layering issue currently prevents bufferization from being controlled by the
transform dialect. The effects of the insert_slice -> tensor_store forwarding are best
observed before bufferization removes the tensor_store op.
To allow more modular testing and a cleaner separation of concerns, a temporary, test-only
flag is added to disable bufferization and exhibit the proper forwarding behavior.
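
For reference, the updated lit test in iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
exercises the new flag through its RUN line:

```mlir
// RUN: iree-opt %s  -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-use-linalg-transform-interp --linalg-transform-interp-disable-bufferization --linalg-transform-file-name=%p/linalg_transform_spec.mlir | FileCheck %s
```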

Finally, a previously failing integration test (iree/test/e2e/linalg_transform/linalg_transform.mlir)
is now correct and no longer needs the util.do_not_optimize workaround.

* Refactor and reuse the foldOffsetsSizesAndStrides helper function
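
The helper now lives in iree/compiler/Codegen/Utils/Utils.cpp and combines a producer and a
consumer that implement OffsetSizeAndStrideOpInterface as
offsets = producer_offsets * consumer_strides + consumer_offsets, sizes = producer_sizes, and
strides = producer_strides * consumer_strides. As an illustrative example with made-up values:
folding a producer slice at offset 4 with stride 1 into a consumer view at offset 10 with
stride 2 gives a combined offset of 4 * 2 + 10 = 18, a combined stride of 1 * 2 = 2, and the
producer's size (sizes are not clamped with min, per the TODO in the header).
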
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index c12ab0f..1454782 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -22,6 +22,9 @@
     "//llvm-external-projects/iree-dialects:IREELinalgExtPasses": [
         "IREELinalgExtPasses"
     ],
+    "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms": [
+        "IREELinalgExtTransforms"
+    ],
     "@torch-mlir-dialects//:TorchMLIRTMTensorDialect": [
         "TorchMLIRTMTensorDialect"
     ],
diff --git a/iree/compiler/Codegen/Common/BUILD b/iree/compiler/Codegen/Common/BUILD
index d0f834c..ca53ca8 100644
--- a/iree/compiler/Codegen/Common/BUILD
+++ b/iree/compiler/Codegen/Common/BUILD
@@ -76,6 +76,7 @@
         "//iree/compiler/Dialect/Util/IR",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
+        "//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Affine",
         "@llvm-project//mlir:AffineUtils",
diff --git a/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp b/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp
index c36da26..2500fc7 100644
--- a/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp
+++ b/iree/compiler/Codegen/Common/BufferizeCopyOnlyDispatchesPass.cpp
@@ -13,6 +13,7 @@
 
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -28,88 +29,6 @@
 namespace mlir {
 namespace iree_compiler {
 
-/// Helper function to create `AffineExpr` from `OpFoldResult`. If the
-/// `OpFoldResult` is a `Value`, creates a `AffineSymbolExpr` and appends it to
-/// `symbols`.
-static AffineExpr getAffineExpr(OpFoldResult ofr, SmallVector<Value> &symbols) {
-  if (auto attr = ofr.dyn_cast<Attribute>()) {
-    return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(),
-                                 attr.getContext());
-  }
-  Value v = ofr.get<Value>();
-  AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext());
-  symbols.push_back(v);
-  return expr;
-}
-/// Converts an `AffineExpr` to `OpFoldResult` by generating an `affine.apply`
-/// operation.
-static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc,
-                                    AffineExpr expr,
-                                    SmallVector<Value> &symbols) {
-  AffineMap m = AffineMap::get(0, symbols.size(), expr);
-  return applyMapToValues(builder, loc, m, symbols)[0];
-}
-
-/// Methods to build the Affine Expr for arithmetic operations.
-static AffineExpr add(AffineExpr expr, OpFoldResult ofr,
-                      SmallVector<Value> &symbols) {
-  return expr + getAffineExpr(ofr, symbols);
-}
-static AffineExpr add(OpFoldResult lhs, OpFoldResult rhs,
-                      SmallVector<Value> &symbols) {
-  return getAffineExpr(lhs, symbols) + getAffineExpr(rhs, symbols);
-}
-static AffineExpr mul(AffineExpr expr, OpFoldResult ofr,
-                      SmallVector<Value> &symbols) {
-  return expr * getAffineExpr(ofr, symbols);
-}
-static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs,
-                      SmallVector<Value> &symbols) {
-  return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols);
-}
-
-/// Returns the offsets to use when combining two operations that implement the
-/// `OffsetSizeAndStrideOpInterface`. Also checks that the strides are 1.
-static LogicalResult foldOffsetsSizesAndStrides(
-    PatternRewriter &rewriter, Location loc,
-    OffsetSizeAndStrideOpInterface producer,
-    OffsetSizeAndStrideOpInterface consumer,
-    SmallVector<OpFoldResult> &combinedOffsets,
-    SmallVector<OpFoldResult> &combinedSizes,
-    SmallVector<OpFoldResult> &combinedStrides) {
-  SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
-  SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
-  SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
-  SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
-  if (producerOffsets.size() != consumerOffsets.size()) {
-    return rewriter.notifyMatchFailure(
-        consumer,
-        "expected op and producer to have same number of offset values");
-  }
-
-  combinedOffsets.resize(producerOffsets.size());
-  combinedSizes.resize(producerOffsets.size());
-  combinedStrides.resize(producerOffsets.size());
-  for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
-    SmallVector<Value> offsetSymbols, strideSymbols;
-    // The combined offset is computed as
-    //    producer_offset + consumer_offset * producer_strides.
-    combinedOffsets[i] = getOpFoldResult(
-        rewriter, loc,
-        add(mul(consumerOffsets[i], producerStrides[i], offsetSymbols),
-            producerOffsets[i], offsetSymbols),
-        offsetSymbols);
-    // The combined stride is computed as
-    //    producer_stride * consumer_stride.
-    combinedStrides[i] = getOpFoldResult(
-        rewriter, loc,
-        mul(producerStrides[i], consumerStrides[i], strideSymbols),
-        strideSymbols);
-  }
-  combinedSizes = consumer.getMixedSizes();
-  return success();
-}
-
 /// Returns the `hal.interface.binding` a value comes from.
 static Optional<IREE::HAL::InterfaceBindingSubspanOp> getBindingSubspanOp(
     Value v) {
@@ -142,9 +61,11 @@
     if (!dispatchTensorLoadOp) return failure();
 
     SmallVector<OpFoldResult> offsets, sizes, strides;
+    // `tensor.extract_slice` (i.e. the producer) folds **into**
+    // `flow.dispatch.tensor.load` (i.e. the consumer).
     if (failed(foldOffsetsSizesAndStrides(
-            rewriter, dispatchTensorLoadOp->getLoc(), dispatchTensorLoadOp,
-            extractSliceOp, offsets, sizes, strides))) {
+            rewriter, dispatchTensorLoadOp->getLoc(), extractSliceOp,
+            dispatchTensorLoadOp, offsets, sizes, strides))) {
       return failure();
     }
 
@@ -182,12 +103,11 @@
     }
 
     SmallVector<OpFoldResult> offsets, sizes, strides;
-    // Treat the `flow.dispatch.tensor.store` as the producer and the
-    // `tensor.insert_slice` as the consumer since that would be the case for
-    // the final subview created.
+    // `tensor.insert_slice` (i.e. the producer) folds **into**
+    // `flow.dispatch.tensor.store` (i.e. the consumer).
     if (failed(foldOffsetsSizesAndStrides(
-            rewriter, dispatchTensorStoreOp->getLoc(), dispatchTensorStoreOp,
-            insertSliceOp, offsets, sizes, strides))) {
+            rewriter, dispatchTensorStoreOp->getLoc(), insertSliceOp,
+            dispatchTensorStoreOp, offsets, sizes, strides))) {
       return failure();
     }
 
diff --git a/iree/compiler/Codegen/Common/CMakeLists.txt b/iree/compiler/Codegen/Common/CMakeLists.txt
index c093856..4f97df6 100644
--- a/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -53,6 +53,7 @@
   DEPS
     IREELinalgExtDialect
     IREELinalgExtPasses
+    IREELinalgExtTransforms
     LLVMSupport
     MLIRAffine
     MLIRAffineUtils
diff --git a/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp b/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp
index dd4d27d..3bb38e9 100644
--- a/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp
+++ b/iree/compiler/Codegen/Common/SetNumWorkgroupsFromLinalgExtPass.cpp
@@ -6,10 +6,12 @@
 
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
 #include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
@@ -60,6 +62,49 @@
   }
 };
 
+/// Forward LinalgExt::InParallel -> Tensor::InsertSlice -> Flow::TensorStore.
+/// This pattern is necessary for correctness: it accounts for the fact that
+/// InParallel is distributed across multiple workgroups when lowering to HAL
+/// but then connects to a sequential tensor.insert_slice and only later to
+/// flow.dispatch.tensor_store.
+///
+// TODO: All the rewrites in this file should be done as part of the InParallel
+// -> HAL rewrite. But because of dialect dependencies and layering, we have
+// some phase ordering that prevents it atm.
+class ForwardInParallelResultToFlow
+    : public OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
+ public:
+  using OpRewritePattern<IREE::Flow::DispatchTensorStoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp op,
+                                PatternRewriter &rewriter) const override {
+    auto insertSliceOp = op.value().getDefiningOp<tensor::InsertSliceOp>();
+    if (!insertSliceOp) return failure();
+
+    // TODO: this should be done as part of InParallel -> HAL rewrite.
+    // But because of dialect dependencies and layering, we have some phase
+    // ordering that prevents it atm. It does not make sense to move the pattern
+    // because of this temporary layering problem, so we just ignore the
+    // condition for now.
+    //
+    // auto inParallelOp =
+    //     insertSliceOp.source().getDefiningOp<IREE::LinalgExt::InParallelOp>();
+    // if (!inParallelOp) return failure();
+
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    // `tensor.insert_slice` (i.e. the producer) folds **into**
+    // `flow.dispatch.tensor.store` (i.e. the consumer).
+    if (failed(foldOffsetsSizesAndStrides(rewriter, op.getLoc(), insertSliceOp,
+                                          op, offsets, sizes, strides)))
+      return failure();
+    rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
+        op, insertSliceOp.source(), op.target(), op.target_dims(), offsets,
+        sizes, strides);
+
+    return success();
+  }
+};
+
 }  // namespace
 
 void SetNumWorkgroupsFromLinalgExtPass::runOnOperation() {
@@ -77,10 +122,17 @@
                                      IREE::HAL::InterfaceWorkgroupCountOp>,
               OneToOneRewritePattern<HALReturnOp, IREE::HAL::ReturnOp>>(
           context);
-  if (failed(
-          applyPatternsAndFoldGreedily(module, std::move(oneToOneRewrites)))) {
+  if (failed(applyPatternsAndFoldGreedily(module, std::move(oneToOneRewrites))))
     return signalPassFailure();
-  }
+
+  // Apply forwarding patterns to bridge the tensor / flow gap.
+  // This is necessary for correctness.
+  // TODO: given existing bufferization tricks, this may trigger unnecessary
+  // copies that need to be further investigated.
+  RewritePatternSet forwardPatterns(context);
+  forwardPatterns.insert<ForwardInParallelResultToFlow>(context);
+  if (failed(applyPatternsAndFoldGreedily(module, std::move(forwardPatterns))))
+    return signalPassFailure();
 
   llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
       getAllEntryPoints(module);
@@ -116,8 +168,9 @@
     });
   }
 
-  // Apply post distribution canonicalization passes.
+  // Apply post-distribution canonicalization passes.
   RewritePatternSet canonicalization(context);
+  AffineApplyOp::getCanonicalizationPatterns(canonicalization, context);
   AffineMinOp::getCanonicalizationPatterns(canonicalization, context);
   populateAffineMinSCFCanonicalizationPattern(canonicalization);
   IREE::Flow::populateFlowDispatchCanonicalizationPatterns(canonicalization,
diff --git a/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 7203608..2973bac 100644
--- a/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -35,6 +35,13 @@
                    "before conversion to LLVM IR"),
     llvm::cl::init(false));
 
+// TODO: Remove this flag once we can call bufferize from the transform dialect.
+static llvm::cl::opt<bool> clDisableLinalgTransformInterpBufferization(
+    "linalg-transform-interp-disable-bufferization",
+    llvm::cl::desc("Disables bufferization when running the linalg transform "
+                   "interp pass (testing only)."),
+    llvm::cl::init(false));
+
 //===---------------------------------------------------------------------===//
 // Default allocation functions for CPU backend
 //===---------------------------------------------------------------------===//
@@ -428,6 +435,10 @@
   // Sets the number of workgroups using kFakeHAL op information.
   passManager.addPass(createSetNumWorkgroupsFromLinalgExtPass());
 
+  // TODO: Remove this flag and the code below once we can call bufferize from
+  // the transform dialect.
+  if (clDisableLinalgTransformInterpBufferization) return;
+
   OpPassManager &modulePM = passManager.nest<ModuleOp>();
   // Bufferize the dispatch.
   BufferizationOptions::AllocationFn allocationFn =
diff --git a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
index b4dc6f8..87c6cec 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
+++ b/iree/compiler/Codegen/LLVMCPU/test/linalg_transform.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s  -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-use-linalg-transform-interp --linalg-transform-file-name=%p/linalg_transform_spec.mlir | FileCheck %s
+// RUN: iree-opt %s  -pass-pipeline='hal.executable(hal.executable.variant(iree-llvmcpu-lower-executable-target))' --iree-codegen-use-linalg-transform-interp --linalg-transform-interp-disable-bufferization --linalg-transform-file-name=%p/linalg_transform_spec.mlir | FileCheck %s
 
 #device_target_cpu = #hal.device.target<"cpu", {executable_targets = [#hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>]}>
 #executable_layout = #hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>
@@ -18,21 +18,25 @@
         %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readwrite:250x1020xf32>
         %3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [250, 500], strides = [1, 1] : !flow.dispatch.tensor<readonly:250x500xf32> -> tensor<250x500xf32>
         %4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [500, 1020], strides = [1, 1] : !flow.dispatch.tensor<readonly:500x1020xf32> -> tensor<500x1020xf32>
-        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [250, 1020], strides = [1, 1] : !flow.dispatch.tensor<readwrite:250x1020xf32> -> tensor<250x1020xf32>
 
         // CHECK: hal.executable.entry_point public @pad_matmul_static_dispatch_0 ordinal(0) layout(#executable_layout) {
         // CHECK:   %[[C1:.*]] = arith.constant 1 : index
         // CHECK: %[[C125:.*]] = arith.constant 125 : index
         // CHECK: hal.return %[[C125]], %[[C1]], %[[C1]] : index, index, index
 
-        //  CHECK-NOT: flow
+        %50 = linalg.init_tensor [250, 1020] : tensor<250x1020xf32>
+        %cst = arith.constant 0.000000e+00 : f32
+        %5 = linalg.fill ins(%cst : f32) outs(%50 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
+
         //  CHECK-NOT: iree_linalg_ext
         //      CHECK: %[[IDX:.*]] = hal.interface.workgroup.id[0] : index
         //      CHECK: %[[OFF:.*]] = affine.apply #[[$map0]]()[%[[IDX]]]
         //      CHECK:  %[[SZ:.*]] = affine.min #map1()[%[[IDX]]]
-        //      CHECK: subview {{.*}}[%[[OFF]]{{.*}}[%[[SZ]]
-        //      CHECK: subview {{.*}}[%[[OFF]]{{.*}}[%[[SZ]]
-        //      CHECK: matmul{{.*}}ins{{.*}}outs
+        //      CHECK: tensor.extract_slice {{.*}}[%[[OFF]]{{.*}}[%[[SZ]]
+        //      CHECK: tensor.extract_slice {{.*}}[%[[OFF]]{{.*}}[%[[SZ]]
+        //      CHECK:  %[[MM:.*]] = linalg.matmul{{.*}}ins{{.*}}outs
+        //      CHECK: %[[OFF2:.*]] = affine.apply #[[$map0]]()[%[[IDX]]]
+        //      CHECK: flow.dispatch.tensor.store %[[MM]], %{{.*}}, offsets = [%[[OFF2]]{{.*}} : tensor<?x1020xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>
         //      CHECK: return
         %6 = linalg.matmul ins(%3, %4 : tensor<250x500xf32>, tensor<500x1020xf32>) outs(%5 : tensor<250x1020xf32>) -> tensor<250x1020xf32>
         flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [250, 1020], strides = [1, 1] : tensor<250x1020xf32> -> !flow.dispatch.tensor<readwrite:250x1020xf32>
diff --git a/iree/compiler/Codegen/Utils/Utils.cpp b/iree/compiler/Codegen/Utils/Utils.cpp
index d3e6da9..1ec4e77 100644
--- a/iree/compiler/Codegen/Utils/Utils.cpp
+++ b/iree/compiler/Codegen/Utils/Utils.cpp
@@ -590,5 +590,87 @@
                std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
 }
 
+/// Helper function to create `AffineExpr` from `OpFoldResult`. If the
+/// `OpFoldResult` is a `Value`, creates an `AffineSymbolExpr` and appends it to
+/// `symbols`.
+static AffineExpr getAffineExpr(OpFoldResult ofr, SmallVector<Value> &symbols) {
+  if (auto attr = ofr.dyn_cast<Attribute>()) {
+    return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(),
+                                 attr.getContext());
+  }
+  Value v = ofr.get<Value>();
+  AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext());
+  symbols.push_back(v);
+  return expr;
+}
+/// Converts an `AffineExpr` to `OpFoldResult` by generating an `affine.apply`
+/// operation.
+static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc,
+                                    AffineExpr expr,
+                                    SmallVector<Value> &symbols) {
+  AffineMap m = AffineMap::get(0, symbols.size(), expr);
+  return applyMapToValues(builder, loc, m, symbols)[0];
+}
+
+/// Methods to build the Affine Expr for arithmetic operations.
+static AffineExpr add(AffineExpr expr, OpFoldResult ofr,
+                      SmallVector<Value> &symbols) {
+  return expr + getAffineExpr(ofr, symbols);
+}
+static AffineExpr add(OpFoldResult lhs, OpFoldResult rhs,
+                      SmallVector<Value> &symbols) {
+  return getAffineExpr(lhs, symbols) + getAffineExpr(rhs, symbols);
+}
+static AffineExpr mul(AffineExpr expr, OpFoldResult ofr,
+                      SmallVector<Value> &symbols) {
+  return expr * getAffineExpr(ofr, symbols);
+}
+static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs,
+                      SmallVector<Value> &symbols) {
+  return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols);
+}
+
+/// Returns the offsets, sizes and strides to use when combining two operations
+/// that implement the `OffsetSizeAndStrideOpInterface`.
+LogicalResult foldOffsetsSizesAndStrides(
+    PatternRewriter &rewriter, Location loc,
+    OffsetSizeAndStrideOpInterface producer,
+    OffsetSizeAndStrideOpInterface consumer,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
+  SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
+  SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
+  SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
+  if (consumerOffsets.size() != producerOffsets.size()) {
+    return rewriter.notifyMatchFailure(
+        producer,
+        "expected op and consumer to have same number of offset values");
+  }
+
+  combinedOffsets.resize(consumerOffsets.size());
+  combinedSizes.resize(consumerOffsets.size());
+  combinedStrides.resize(consumerOffsets.size());
+  for (auto i : llvm::seq<unsigned>(0, consumerOffsets.size())) {
+    SmallVector<Value> offsetSymbols, strideSymbols;
+    // The combined offset is computed as
+    //    consumer_offset + producer_offset * consumer_strides.
+    combinedOffsets[i] = getOpFoldResult(
+        rewriter, loc,
+        add(mul(producerOffsets[i], consumerStrides[i], offsetSymbols),
+            consumerOffsets[i], offsetSymbols),
+        offsetSymbols);
+    // The combined stride is computed as
+    //    consumer_stride * producer_stride.
+    combinedStrides[i] = getOpFoldResult(
+        rewriter, loc,
+        mul(consumerStrides[i], producerStrides[i], strideSymbols),
+        strideSymbols);
+  }
+  combinedSizes = producer.getMixedSizes();
+  return success();
+}
+
 }  // namespace iree_compiler
 }  // namespace mlir
diff --git a/iree/compiler/Codegen/Utils/Utils.h b/iree/compiler/Codegen/Utils/Utils.h
index be5322e..4e4d06d 100644
--- a/iree/compiler/Codegen/Utils/Utils.h
+++ b/iree/compiler/Codegen/Utils/Utils.h
@@ -16,6 +16,9 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
 namespace iree_compiler {
@@ -155,6 +158,23 @@
 /// Returns the option that distributes the ops using the flow workgroup
 /// ID/Count operations.
 linalg::LinalgLoopDistributionOptions getIREELinalgLoopDistributionOptions();
+
+/// Returns the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when folding a "producer" **into** a "consumer" op that implement
+/// `OffsetSizeAndStrideOpInterface`.
+/// The following computations are performed:
+///   - offsets = producer_offsets * consumer_strides + consumer_offsets,
+///   - sizes = producer_sizes
+///   - strides = producer_strides * consumer_strides.
+// TODO: Sizes should technically be combined with `min` but one often has
+// enough static knowledge to avoid this extra complexity.
+LogicalResult foldOffsetsSizesAndStrides(
+    PatternRewriter &rewriter, Location loc,
+    OffsetSizeAndStrideOpInterface producer,
+    OffsetSizeAndStrideOpInterface consumer,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides);
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/test/e2e/linalg_transform/linalg_transform.mlir b/iree/test/e2e/linalg_transform/linalg_transform.mlir
index 2ac2952..eeda457 100644
--- a/iree/test/e2e/linalg_transform/linalg_transform.mlir
+++ b/iree/test/e2e/linalg_transform/linalg_transform.mlir
@@ -18,14 +18,9 @@
     [10.0, 09.0, 08.0, 07.0, 06.0],
     [05.0, 04.0, 03.0, 02.0, 01.0]]> : tensor<3x5xf32> -> tensor<3x5xf32>
 
-  // util.do_not_optimize on output to prevent fusing in the same dispatch
-  // region which would be subject to racy tensor semantics.
-  // Forcing different dispatches forces flow.dispatch.tensor.load which is
-  // actually side-effecting.
-  %res_in = util.do_not_optimize(%res) : tensor<5x5xf32>
   %matmul = linalg.matmul
       ins(%lhs, %rhs : tensor<5x3xf32>, tensor<3x5xf32>)
-      outs(%res_in : tensor<5x5xf32>) -> tensor<5x5xf32>
+      outs(%res : tensor<5x5xf32>) -> tensor<5x5xf32>
   %matmul_res = util.do_not_optimize(%matmul) : tensor<5x5xf32>
 
   return %matmul_res : tensor<5x5xf32>