Add a pass to rewrite destructive update patterns involving Linalg on tensors. (#4081)

This pass may be called after transformations on Linalg on tensors and includes special
detection and rewrite of destructive update patterns.
Such destructive updates are assumed to come from tiling of parallel iterators only.
supporting tiling across reduction iterators in the tensor domain will require additional machinery and possibly dependence analysis.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/BUILD b/iree/compiler/Conversion/LinalgToLLVM/BUILD
index bb93b48..0ef31cb 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/BUILD
+++ b/iree/compiler/Conversion/LinalgToLLVM/BUILD
@@ -24,6 +24,7 @@
         "ConvImg2ColMatmulConversion.cpp",
         "ConvertToLLVM.cpp",
         "KernelDispatch.cpp",
+        "LinalgRewriteDestructiveUpdatesPass.cpp",
         "LinalgTileAndDistributePass.cpp",
         "LinalgTileAndVectorizePass.cpp",
         "Passes.cpp",
@@ -54,9 +55,11 @@
         "@llvm-project//mlir:LinalgToLLVM",
         "@llvm-project//mlir:LinalgTransforms",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:StandardOpsTransforms",
         "@llvm-project//mlir:StandardToSPIRVConversions",
+        "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
         "@llvm-project//mlir:VectorOps",
         "@llvm-project//mlir:VectorToLLVM",
diff --git a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
index df9c39a..46d3787 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToLLVM/CMakeLists.txt
@@ -24,6 +24,7 @@
     "ConvImg2ColMatmulConversion.cpp"
     "ConvertToLLVM.cpp"
     "KernelDispatch.cpp"
+    "LinalgRewriteDestructiveUpdatesPass.cpp"
     "LinalgTileAndDistributePass.cpp"
     "LinalgTileAndVectorizePass.cpp"
     "Passes.cpp"
@@ -37,11 +38,13 @@
     MLIRLinalgToLLVM
     MLIRLinalgTransforms
     MLIRPass
+    MLIRSCF
     MLIRSCFToStandard
     MLIRStandard
     MLIRStandardOpsTransforms
     MLIRStandardToLLVM
     MLIRStandardToSPIRVTransforms
+    MLIRSupport
     MLIRTransforms
     MLIRVector
     MLIRVectorToLLVM
diff --git a/iree/compiler/Conversion/LinalgToLLVM/LinalgRewriteDestructiveUpdatesPass.cpp b/iree/compiler/Conversion/LinalgToLLVM/LinalgRewriteDestructiveUpdatesPass.cpp
new file mode 100644
index 0000000..e4bd118
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/LinalgRewriteDestructiveUpdatesPass.cpp
@@ -0,0 +1,547 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- LinalgRewriteDestructiveUpdates.cpp - Pass for destructive updates--===//
+//
+// Pass to rewrite Linalg on tensors destructive updates into updates through
+// memory.
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-linalg-bufferize"
+
+namespace mlir {
+namespace iree_compiler {
+
+// Detect the pattern:
+// %d0 = for  (iter_args %e0 = %0)
+//   ...
+//   %dk = for ((iter_args %ek = %e{k-1}))
+//     ...
+//       %dn = destructive-update-op (%en)
+//       yield %dn
+//     ...
+//     yield %dk
+//   yield %dk
+struct SpecialTerminatorOpCapture {
+  Value initValue;
+  // For now, must be scf::ForOps.
+  SmallVector<Operation *, 4> loops;
+  // For now, must be a SubTensorInsertOp.
+  Operation *rootDestructiveUpdate;
+  bool readOnly = false;
+  bool writeOnly = false;
+};
+
+// TODO: Use some interface instead of op names directly.
+static bool hasDestructiveUpdateSubTensorUses(
+    Value v, SpecialTerminatorOpCapture &capture) {
+  SmallVector<SubTensorOp, 4> reads;
+  SmallVector<SubTensorInsertOp, 4> writes;
+  for (auto &u : v.getUses()) {
+    if (auto subTensorOp = dyn_cast<SubTensorOp>(u.getOwner())) {
+      reads.push_back(subTensorOp);
+      continue;
+    }
+    if (auto subTensorInsertOp = dyn_cast<SubTensorInsertOp>(u.getOwner())) {
+      writes.push_back(subTensorInsertOp);
+      continue;
+    }
+    LLVM_DEBUG(llvm::dbgs() << "found non-destructive update pattern use: "
+                            << *(u.getOwner()) << "\n");
+    return false;
+  }
+  // For now, only allow exactly a single SubTensorInsertOp that must be
+  // dominated by all SubTensorOp.
+  if (writes.size() != 1) return false;
+  // Small local dominance computation.
+  DominanceInfo domInfo(writes.front().getParentOp());
+  for (auto read : reads) {
+    LLVM_DEBUG(llvm::dbgs() << "read: " << *read.getOperation() << "\n");
+    if (!domInfo.properlyDominates(read.getOperation(), writes.front())) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "non-destructive use-def: " << *(read.getOperation())
+                 << " does not properly dominate "
+                 << *(writes.front().getOperation()) << "\n");
+      return false;
+    }
+  }
+
+  capture.readOnly = writes.empty();
+  capture.writeOnly = reads.empty();
+  capture.rootDestructiveUpdate = writes.front();
+  LLVM_DEBUG(llvm::dbgs() << "readOnly: " << capture.readOnly
+                          << " writeOnly: " << capture.writeOnly << "\n");
+  return true;
+}
+
+// Determine whether `tensor` is produced by a destructive update of another
+// tensor. When successful, fill a SpecialTerminatorOpCapture struct that
+// captures the relevant (distributed) pieces of IR that for the destructive
+// update pattern. There are 2 cases.
+//
+// Simple case
+// ===========
+//
+// Just detect a SubTensorInsertOp.
+//
+// Loop case
+// =========
+//
+// Iteratively traverse an (imperfectly nested) loop nest such as:
+//
+// ```
+// %d0 = for  (iter_args %e0 = %0)
+//   ...
+//   %dk = for ((iter_args %ek = %e{k-1}))
+//     ...
+//     %dn = destructive-update-op (%en)
+//     yield %dn
+//     ...
+//   yield %dk
+// ```
+//
+// to determine whether `d0` is produced by a scf::ForOp with destructive
+// update semantics.
+//
+// Return the value into which the destructive update occurs.
+// Return nullptr if `tensor` is not a destructive update of some other tensor
+// value.
+static Value isADestructiveUpdatePattern(Value tensor,
+                                         SpecialTerminatorOpCapture &capture) {
+  // Simple case: no loops and directly a tensorInsertOp.
+  if (auto tensorInsertOp =
+          dyn_cast_or_null<SubTensorInsertOp>(tensor.getDefiningOp())) {
+    capture.rootDestructiveUpdate = tensorInsertOp;
+    return tensorInsertOp.dest();
+  }
+
+  Value returnValue;
+  while (auto scfForOp = dyn_cast_or_null<scf::ForOp>(tensor.getDefiningOp())) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Step destructive update pattern: " << scfForOp << "\n");
+    // Capture the loop.
+    capture.loops.push_back(scfForOp);
+    // Analyze the iterArg at the proper position.
+    unsigned idx = tensor.cast<OpResult>().getResultNumber();
+    BlockArgument regionArg = *(scfForOp.getRegionIterArgs().begin() + idx);
+    // Set return value if not yet set.
+    if (!returnValue) returnValue = *(scfForOp.getIterOperands().begin() + idx);
+
+    // Case 1: zero use -> no destructive update.
+    if (regionArg.use_empty()) return nullptr;
+
+    // Case 2: multiple uses from an scf::ForOp then this must be used only by
+    // SubTensorOp / SubTensorInsertOp with proper dominance.
+    if (!regionArg.hasOneUse()) {
+      if (!hasDestructiveUpdateSubTensorUses(regionArg, capture))
+        return nullptr;
+      return returnValue;
+    }
+
+    assert(regionArg.hasOneUse());
+    LLVM_DEBUG(llvm::dbgs() << "one use analysis: " << regionArg << "\n");
+    OpOperand *operand = regionArg.getUses().begin().getOperand();
+    auto innerForOp = dyn_cast<scf::ForOp>(operand->getOwner());
+    // Case 3a: Single use which is not an scf::ForOp, it may still be a
+    // single SubTensor / SubTensorInsertOp.
+    if (!innerForOp) {
+      if (!hasDestructiveUpdateSubTensorUses(regionArg, capture))
+        return nullptr;
+      return returnValue;
+    }
+
+    // Case 3b: Single use which is an scf::ForOp: `innerIterArgIdx` is the
+    // candidate result and iterArg number.
+    unsigned innerIterArgIdx =
+        operand->getOperandNumber() - innerForOp.getNumControlOperands();
+    Value innerForOpResultTensor = innerForOp.getResult(innerIterArgIdx);
+    Value yieldValue =
+        scfForOp.region().front().getTerminator()->getOperand(idx);
+
+    // Check that the return position of dk and the yield position of dk
+    // agree (in the loop structure below). This avoids ping-pong effects
+    // between operands, yields and results.
+    //
+    // %d0 = for  (iter_args %e0 = %0)
+    //   ...
+    //   %dk = for ((iter_args %ek = %e{k-1}))
+    //     ...
+    //     %dn = destructive-update-op (%en)
+    //     yield %dn
+    //     ...
+    //   yield %dk
+    LLVM_DEBUG(llvm::dbgs()
+               << "innerForOpResultTensor: " << innerForOpResultTensor << "\n"
+               << "yieldValue: " << yieldValue << "\n"
+               << "step in: " << (innerForOpResultTensor == yieldValue)
+               << "\n");
+    if (innerForOpResultTensor != yieldValue) return nullptr;
+
+    // Prepare for the next level with the innerForOp's result at position
+    // `innerIterArgIdx`.
+    tensor = innerForOp.getResult(innerIterArgIdx);
+    LLVM_DEBUG(llvm::dbgs() << "next tensor: " << tensor << "\n");
+  }
+  return nullptr;
+}
+
+/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to a
+/// hal.interface.load.tensor.tile.
+static LogicalResult propagateSubTensorOp(OpBuilder &b, SubTensorOp op) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  auto loadOp = op.source().getDefiningOp<IREE::HAL::InterfaceLoadTensorOp>();
+  if (!loadOp) {
+    BlockArgument val = op.source().dyn_cast<BlockArgument>();
+    while (val) {
+      auto forOp = dyn_cast<scf::ForOp>(val.getOwner()->getParentOp());
+      // val is a block argument but not to an scf::ForOp -> bail.
+      if (!forOp) return failure();
+      unsigned idx = val.getArgNumber() - 1;  // accounting for IV arg.
+      Value iterOperand = *(forOp.getIterOperands().begin() + idx);
+      loadOp = iterOperand.getDefiningOp<IREE::HAL::InterfaceLoadTensorOp>();
+      val = iterOperand.dyn_cast<BlockArgument>();
+    }
+  }
+  if (!loadOp) return failure();
+
+  Value loaded = b.create<IREE::HAL::InterfaceLoadTensorTileOp>(
+      op.getLoc(), op.getResult().getType(), loadOp.binding(), loadOp.offset(),
+      op.offsets(), op.sizes(), op.strides(), op.static_offsets(),
+      op.static_sizes(), op.static_strides());
+  op.getResult().replaceAllUsesWith(loaded);
+  op.erase();
+  return success();
+}
+
+static LogicalResult rewriteSubTensorInsertInPlace(OpBuilder &b,
+                                                   SubTensorInsertOp op,
+                                                   SymbolRefAttr binding,
+                                                   Value offset) {
+  LLVM_DEBUG(llvm::dbgs() << "RewriteSubTensorInsertInPlace: "
+                          << *(op.getOperation()) << "\n");
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  auto dest = op.dest();
+
+  // Sanity check for insert into an initArg that is immediately yielded.
+  if (!dest.isa<BlockArgument>() || !op.getResult().hasOneUse() ||
+      !isa<scf::YieldOp>(op.getResult().getUses().begin()->getOwner())) {
+    LLVM_DEBUG(llvm::dbgs() << "Rewrite failed: no single scf::YieldOp use\n");
+    return failure();
+  }
+
+  // Kills the SSA use-def chain.
+  op.replaceAllUsesWith(dest);
+  b.create<IREE::HAL::InterfaceStoreTensorTileOp>(
+      op.getLoc(), TypeRange{}, op.source(), binding, offset, op.offsets(),
+      op.sizes(), op.strides(), op.static_offsets(), op.static_sizes(),
+      op.static_strides());
+  return success();
+}
+
+// Return true if any control flow is found in the FuncOp besides scf::ForOp.
+static bool hasNonScfForControlFlow(FuncOp funcOp) {
+  return funcOp
+      .walk([&](Operation *op) {
+        if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op)) {
+          if (!isa<scf::ForOp>(op)) return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
+}
+
+// Rewrite specific SubTensor / SubTensorInsert ops that match a "destructive
+// tensor update" pattern, by an inplace update at `binding` and `offset1, using
+// hal.interface.*.tensor.tile ops.
+// This serves as a step in jumping the abstraction gap between transformed
+// "linalg on tensors" IR and the buffer world.
+// This is possible because we control the production of such patterns in IREE
+// and can take the necessary shortcuts wrt inplace semantics.
+// In the future it is reasonable to expect special IR constructs to capture
+// some of the destructive update patterns,
+//
+// Assumptions/Invariants on "Control the Production of Such Patterns"
+// ===================================================================
+// 1. Input tensors may not participate in a destructive update pattern.
+// 2. Init and output tensors may participate in a destructive update pattern.
+// 3. No init or output tensor backing storage aliases with any other tensor
+//    storage.
+// 4. SubTensorOp/SubTensorInsertOp are the only ops that can extract/insert
+//    from/into tensors.
+// 5. All SubTensorOp/SubTensorInsertOp must have been introduced by Linalg
+//    tiling on tensors.
+// 6. Such tilings that result in yielded tensors across loops may only tile
+//    parallel Linalg iterators atm.
+// 7. (Future) Allow non-parallel Linalg iterators tiling and ensure first-read
+//    or writeOnly by construction.
+//
+// Note: the assumptions/invariants above are subject to changing ordering of
+// passes. When dispatch region and hal.interfaces are created on the linalg on
+// buffers path, these are all assumptions. In the future, when dispatch regions
+// and hal.interfaces are created post-transformations on the linalg on tensors
+// path some assumptions will become invariants.
+//
+// For now, the following destructive update patterns are rewritten.
+//
+// Coming from an `InterfaceLoadTensorOp`
+// ======================================
+// ```
+//   %0 = hal.interface.load.tensor @x[offsetx]
+//   ...
+//   %1 = destructive_update(%0)
+//   ...
+//   use_of(%1) // e.g. hal.interface.store.tensor %1 @y[offsety]
+// ```
+// is rewritten into:
+// ```
+//   %0 = hal.interface.load.tensor @x[offsetx]
+//   ...
+//   inplace_update @binding[offset]
+//   %2 = hal.interface.load.tensor  @binding[offset]
+//   ...
+//   use_of(%2) // e.g. hal.interface.store.tensor %2 @y[offsety]
+// ```
+//
+// This is a typical pattern that appears after tiling Linalg ops on tensors
+// with operands that come from hal.interface.
+//
+// Coming from a `LinalgOp`
+// =========================
+// ```
+//   %0 = linalg-op
+//   ...
+//   %1 = destructive_update(%0) // only subtensor_inserts into %0
+//   ...
+//   use_of(%1) // e.g. hal.interface.store.tensor %1 @y
+// ```
+// is rewritten into:
+// ```
+//   %0 = linalg-op
+//   ...
+//   inplace_update @binding[offset]
+//   %2 = hal.interface.load.tensor @binding[offset]
+//   ...
+//   hal.interface.store.tensor %2 @y[offsety]
+// ```
+// This is a typical pattern that appears after tileAndFuse ops with operands
+// produced by other linalg ops. In this case, tile and fuse leaves %0 behind
+// because it is the op that materializes the full tensor. This could be
+// replaced by a notional "tensor.undef" and the compute would become a dead
+// value.
+// The rewrite breaks the use-def chain for %0 and may result in the linalg-op
+// being DCE'd.
+//
+// Other rewrites:
+// ===============
+// Furthermore, when `@binding` == `@y` and `offset` == `offsety` and `...`
+// contains no aliasing read/write to either `@binding[offset]` or `@y[offsety]`
+// the following:
+// ```
+//   %2 = hal.interface.load.tensor @binding[offset]
+//   ...
+//   hal.interface.store.tensor %2 @y[offsety]
+// ```
+// is elided.
+// This should probably become a dedicated pass based on core alias analysis,
+// when the latter becomes available.
+static LogicalResult rewriteDestructiveUpdateInPlace(OpBuilder &b, Value v,
+                                                     SymbolRefAttr binding,
+                                                     Value offset) {
+  SpecialTerminatorOpCapture capture;
+  capture.initValue = v;
+  Value sourceValue = isADestructiveUpdatePattern(capture.initValue, capture);
+
+  // No destructive update semantics, bail.
+  if (!sourceValue || !capture.rootDestructiveUpdate) return failure();
+
+  Operation *outermostProducingOp = (capture.loops.empty())
+                                        ? capture.rootDestructiveUpdate
+                                        : capture.loops.front();
+  LLVM_DEBUG(llvm::dbgs() << "outermost producing: " << *outermostProducingOp
+                          << "\n");
+
+  // Helper to clone the load right after `outermostProducingOp` and use it in
+  // `op`.
+  auto buildAndUseLoad = [&](SymbolRefAttr bindingAttr, Value offsetValue) {
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointAfter(outermostProducingOp);
+    // TODO: no attributes other than `bindingAttr` on hal.interface.load.tensor
+    // atm. Do we need other attributes propagated?
+    Value newLoad = b.create<IREE::HAL::InterfaceLoadTensorOp>(
+        outermostProducingOp->getLoc(), v.getType(), bindingAttr, offsetValue);
+    // TODO: this brutally replaces all uses by the result of this load.
+    // In practice we may want more recompute and we may have lost information.
+    // Revisit this when things are morefleshed out.
+    v.replaceAllUsesWith(newLoad);
+  };
+
+  // `sourceValue`-specific determination of binding and offset for inplace
+  // update.
+  // TODO: when the destructively updated value comes from a LoadTensorOp, we
+  // need to decide whether to perform an update "into" the init tensor (i.e.
+  // the loadOp binding + offset) or directly into the output.
+  // In the fullness of time, this is dependent on whether or not the
+  // destructive update can guarantee that any particular subtensor of the
+  // result is updated only once.
+  // This is guaranteed for loops that come from tiling "parallel" Linalg
+  // iterators.
+  // Reduction iterators are subject to additional first-read/last-write
+  // considerations, usually derived from traditional memory-based dependence
+  // analysis.
+  // For now we assume that Linalg on tensors only tiles and fuses across
+  // parallel iterators, which allows reading the proper init value and updating
+  // the result.
+  SymbolRefAttr bindingAttr =
+      TypeSwitch<Operation *, SymbolRefAttr>(sourceValue.getDefiningOp())
+          .Case<IREE::HAL::InterfaceLoadTensorOp>(
+              // [&](auto op) { return op.binding(); })
+              [&](auto op) { return binding; })
+          .Case<linalg::LinalgOp>([&](auto op) { return binding; })
+          .Default([](Operation *) { return nullptr; });
+  Value offsetValue =
+      TypeSwitch<Operation *, Value>(sourceValue.getDefiningOp())
+          .Case<IREE::HAL::InterfaceLoadTensorOp>(
+              // [&](auto op) { return op.offset(); })
+              [&](auto op) { return offset; })
+          .Case<linalg::LinalgOp>([&](auto op) { return offset; })
+          .Default([](Operation *) { return nullptr; });
+
+  // TODO: support more cases as needed.
+  if (!bindingAttr) return failure();
+
+  // Try to rewrite inplace.
+  if (failed(rewriteSubTensorInsertInPlace(
+          b, cast<SubTensorInsertOp>(capture.rootDestructiveUpdate),
+          bindingAttr, offsetValue)))
+    return failure();
+
+  // Reload the value produced inplace right after the inplace update.
+  buildAndUseLoad(bindingAttr, offsetValue);
+
+  if (scf::ForOp loopOp = dyn_cast<scf::ForOp>(outermostProducingOp))
+    loopOp.walk([&](SubTensorOp op) { propagateSubTensorOp(b, op); });
+
+  return success();
+}
+
+// TODO: generalize to more than naive "top of the function consecutive ops".
+// Probably better to wait until core alias analysis is upstreamed.
+// TODO: interfaces.
+static bool hasInterleavedAliases(IREE::HAL::InterfaceLoadTensorOp loadOp,
+                                  IREE::HAL::InterfaceStoreTensorOp storeOp) {
+  Block *bLoad = loadOp.getOperation()->getBlock();
+  Block *bStore = loadOp.getOperation()->getBlock();
+  if (!isa<FuncOp>(bLoad->getParentOp()) ||
+      !isa<FuncOp>(bStore->getParentOp()) ||
+      bLoad->getParentOp() != bStore->getParentOp())
+    return true;
+
+  if (storeOp.getOperation()->getPrevNode() != loadOp) return true;
+
+  return false;
+}
+
+namespace {
+struct LinalgRewriteDestructiveUpdates
+    : public PassWrapper<LinalgRewriteDestructiveUpdates, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<IREEDialect, linalg::LinalgDialect, scf::SCFDialect,
+                    StandardOpsDialect>();
+  }
+  void runOnFunction() override;
+};
+}  // namespace
+
+void LinalgRewriteDestructiveUpdates::runOnFunction() {
+  FuncOp funcOp = getFunction();
+
+  // Bail on any control-flow for now.
+  if (hasNonScfForControlFlow(funcOp)) return signalPassFailure();
+
+  MLIRContext *context = &getContext();
+  OpBuilder b(context);
+  // For each tensor store op, look for destructive updates and replace the
+  // destructive pattern by a custom inplace update pattern.
+  funcOp.walk([&](IREE::HAL::InterfaceStoreTensorOp op) {
+    if (failed(rewriteDestructiveUpdateInPlace(b, op.operand(), op.binding(),
+                                               op.offset()))) {
+      signalPassFailure();
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  // For each tensor store op, redundant load/store optimization.
+  funcOp.walk([&](IREE::HAL::InterfaceStoreTensorOp storeOp) {
+    auto loadOp = dyn_cast_or_null<IREE::HAL::InterfaceLoadTensorOp>(
+        storeOp.operand().getDefiningOp());
+
+    // Bail if there exists an interleaved aliasing.
+    if (!loadOp || hasInterleavedAliases(loadOp, storeOp)) return;
+
+    // Bail if this is not a simple forwarding.
+    // TODO: Handle more advanced forwarding, but we may need to do it
+    // earlier while we still have SSA use-def chains.
+    // I.e. revisit later cases such as:
+    // ```
+    //   inplace_update_tiles_rooted_at @legacy_io::@arg0, offset = %c0
+    //   %2 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 :
+    //     tensor<2x4xf32>
+    //   hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0
+    //     {operand_result_index = 3 : i32} : tensor<2x4xf32>
+    //   return
+    // ```
+    // where the inplace update could be done directly in @legacy_io::@ret0,
+    // offset = %c0.
+    if (loadOp.binding() != storeOp.binding() ||
+        loadOp.offset() != storeOp.offset())
+      return;
+
+    storeOp.erase();
+  });
+
+  // Non-default canonicalization patterns.
+  // TODO: add Linalg tiling canonicalization patterns, affineminscf and others
+  // as needed.
+  OwningRewritePatternList canonicalizationPatterns;
+  scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
+  applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizationPatterns));
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgRewriteDestructiveUpdatesPass() {
+  return std::make_unique<LinalgRewriteDestructiveUpdates>();
+}
+
+static PassRegistration<LinalgRewriteDestructiveUpdates> pass(
+    "iree-codegen-linalg-rewrite-destructive-updates",
+    "Test the rewrite of destructive update patterns to inplace update form.",
+    [] { return std::make_unique<LinalgRewriteDestructiveUpdates>(); });
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.h b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
index 0a325dc..09edd29 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/Passes.h
+++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.h
@@ -45,6 +45,11 @@
 /// Performs the final conversion to LLVM dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass();
 
+/// Pass to rewrite Linalg on tensors destructive updates into updates through
+/// memory.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgRewriteDestructiveUpdatesPass();
+
 /// Populates passes needed to lower a XLA HLO op to LLVM dialect via the
 /// structured ops path. The pass manager `pm` in here should operate on the
 /// module within the IREE::HAL::ExecutableOp.
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/linalg_rewrite_destructive_updates.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/linalg_rewrite_destructive_updates.mlir
new file mode 100644
index 0000000..9360101
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/linalg_rewrite_destructive_updates.mlir
@@ -0,0 +1,217 @@
+// RUN: iree-opt %s --iree-codegen-linalg-rewrite-destructive-updates -canonicalize -cse -split-input-file | IreeFileCheck %s
+
+// CHECK-LABEL: func @tile_from_tensor_load
+func @tile_from_tensor_load() {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  %c1 = constant 1 : index
+  %0 = hal.interface.load.tensor @legacy_io::@TENSOR_LHS, offset = %c0
+    {operand_result_index = 0 : i32} : tensor<2x3xf32>
+  %1 = hal.interface.load.tensor @legacy_io::@TENSOR_RHS, offset = %c0
+    {operand_result_index = 1 : i32} : tensor<3x4xf32>
+  %2 = hal.interface.load.tensor @legacy_io::@TENSOR_INIT, offset = %c0
+    {operand_result_index = 2 : i32} : tensor<2x4xf32>
+
+  %3 = iree.workgroup_id {dimension = "x"} : index
+  %4 = iree.workgroup_id {dimension = "y"} : index
+  // Test step %{{[0-9a-z]}} { to ensure yield has been folded away.
+  //      CHECK: scf.for %[[I:.*]] = {{.*}} step %{{[0-9a-z]+}} {
+  //      CHECK:   scf.for %[[J:.*]] = {{.*}} step %{{[0-9a-z]+}} {
+  //  CHECK-DAG:     %[[LHS:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_LHS, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], 0], sizes = [1, 3], strides = [1, 1] : tensor<1x3xf32>
+  //  CHECK-DAG:     %[[RHS:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_RHS, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [0, %[[J]]], sizes = [3, 1], strides = [1, 1] : tensor<3x1xf32>
+  //  CHECK-DAG:     %[[OUT:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_INIT, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], %[[J]]], sizes = [1, 1], strides = [1, 1] : tensor<1x1xf32>
+  // Compute.
+  //      CHECK:     %[[RES:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<1x3xf32>, tensor<3x1xf32>)
+  // CHECK-SAME:                       init(%[[OUT]] : tensor<1x1xf32>)
+  // Inplace update (assume only parallel loops have been tiled).
+  //      CHECK:     hal.interface.store.tensor.tile %[[RES]], @legacy_io::@ret0, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], %[[J]]], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32>
+  %5 = scf.for %arg0 = %4 to %c2 step %c2 iter_args(%arg1 = %2) -> (tensor<2x4xf32>) {
+    %6 = scf.for %arg2 = %3 to %c4 step %c4 iter_args(%arg3 = %arg1) -> (tensor<2x4xf32>) {
+      %7 = subtensor %0[%arg0, 0] [1, 3] [1, 1] : tensor<2x3xf32> to tensor<1x3xf32>
+      %8 = subtensor %1[0, %arg2] [3, 1] [1, 1] : tensor<3x4xf32> to tensor<3x1xf32>
+      %9 = subtensor %arg3[%arg0, %arg2] [1, 1] [1, 1] : tensor<2x4xf32> to tensor<1x1xf32>
+      %10 = linalg.matmul ins(%7, %8 : tensor<1x3xf32>, tensor<3x1xf32>)
+                         init(%9 : tensor<1x1xf32>)
+        -> tensor<1x1xf32>
+      // This op is the destructive update we seek to eliminate.
+      %11 = subtensor_insert %10 into %arg3[%arg0, %arg2] [%c1, %c1] [%c1, %c1] :
+        tensor<1x1xf32> into tensor<2x4xf32>
+      scf.yield %11 : tensor<2x4xf32>
+    }
+    scf.yield %6 : tensor<2x4xf32>
+  }
+
+  // Under parallel iterators tiling assumptions, simple forwarding occurs.
+  //  CHECK-NOT: hal.interface.load.tensor %
+  //  CHECK-NOT: hal.interface.store.tensor %
+   // This is the store onto which the destructive update rewrites latches.
+  hal.interface.store.tensor %5, @legacy_io::@ret0, offset = %c0
+    {operand_result_index = 3 : i32} : tensor<2x4xf32>
+
+  return
+}
+
+// -----
+
+#accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>
+]
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"]
+}
+
+// CHECK-LABEL: func @tile_from_pointwise_lhs
+func @tile_from_pointwise_lhs() {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  %c1 = constant 1 : index
+  %0 = hal.interface.load.tensor @legacy_io::@TENSOR_INIT, offset = %c0
+    {operand_result_index = 0 : i32} : tensor<2x4xf32>
+  // This is the `lhs` operand that gets through a pointwise generic and becomes
+  // the first input operand of linalg.matmul. After tiling and fusion this is a
+  // subtensor of the original `hal.interface.load.tensor`.
+  %1 = hal.interface.load.tensor @legacy_io::@TENSOR_LHS, offset = %c0
+    {operand_result_index = 1 : i32} : tensor<2x3xf32>
+  %2 = hal.interface.load.tensor @legacy_io::@TENSOR_RHS, offset = %c0
+    {operand_result_index = 2 : i32} : tensor<3x4xf32>
+
+  %4 = iree.workgroup_id {dimension = "x"} : index
+  %5 = iree.workgroup_id {dimension = "y"} : index
+  // Test step %{{[0-9a-z]}} { to ensure yield has been folded away.
+  //      CHECK: scf.for %[[I:.*]] = {{.*}} step %{{[0-9a-z]+}} {
+  //      CHECK:   scf.for %[[J:.*]] = {{.*}} step %{{[0-9a-z]+}} {
+  //      CHECK:     %[[LHS:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_LHS, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], 0], sizes = [1, 3], strides = [1, 1] : tensor<1x3xf32>
+  //      CHECK:     %[[RHS:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_RHS, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [0, %[[J]]], sizes = [3, 1], strides = [1, 1] : tensor<3x1xf32>
+  // Compute.
+  //      CHECK:     %[[LHS2:.*]] = linalg.generic {{.*}} ins(%[[LHS]] : tensor<1x3xf32>) {
+  // TODO: should we reorder this load?
+  //      CHECK:     %[[OUT:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_INIT, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], %[[J]]], sizes = [1, 1], strides = [1, 1] : tensor<1x1xf32>
+  // Compute.
+  //      CHECK:     %[[RES:.*]] = linalg.matmul ins(%[[LHS2]], %[[RHS]] : tensor<1x3xf32>, tensor<3x1xf32>)
+  // CHECK-SAME:                       init(%[[OUT]] : tensor<1x1xf32>)
+  // Inplace update (assume only parallel loops have been tiled).
+  //      CHECK:     hal.interface.store.tensor.tile %[[RES]], @legacy_io::@ret0, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], %[[J]]], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32>
+  %6 = scf.for %arg0 = %5 to %c2 step %c2 iter_args(%arg1 = %0) -> (tensor<2x4xf32>) {
+    %7 = scf.for %arg2 = %4 to %c4 step %c4 iter_args(%arg3 = %arg1) -> (tensor<2x4xf32>) {
+      %8 = subtensor %1[%arg0, 0] [1, 3] [1, 1] : tensor<2x3xf32> to tensor<1x3xf32>
+      %9 = subtensor %2[0, %arg2] [3, 1] [1, 1] : tensor<3x4xf32> to tensor<3x1xf32>
+      %10 = linalg.generic #trait ins(%8 : tensor<1x3xf32>) {
+      ^bb0(%arg4: f32):
+        linalg.yield %arg4 : f32
+      } -> tensor<1x3xf32>
+
+      %11 = subtensor %arg3[%arg0, %arg2] [1, 1] [1, 1] : tensor<2x4xf32> to tensor<1x1xf32>
+      %13 = linalg.matmul ins(%10, %9 : tensor<1x3xf32>, tensor<3x1xf32>)
+                         init(%11 : tensor<1x1xf32>) -> tensor<1x1xf32>
+
+      // This op is the destructive update we seek to eliminate.
+      %14 = subtensor_insert %13 into %arg3[%arg0, %arg2] [%c1, %c1] [%c1, %c1] :
+        tensor<1x1xf32> into tensor<2x4xf32>
+      scf.yield %14 : tensor<2x4xf32>
+    }
+    scf.yield %7 : tensor<2x4xf32>
+  }
+
+  // Under parallel iterators tiling assumptions, simple forwarding occurs.
+  //  CHECK-NOT: hal.interface.load.tensor %
+  //  CHECK-NOT: hal.interface.store.tensor %
+  // This is the store onto which the destructive update rewrites latches.
+  hal.interface.store.tensor %6, @legacy_io::@ret0, offset = %c0
+    {operand_result_index = 3 : i32} : tensor<2x4xf32>
+
+  return
+}
+
+// -----
+
+#accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>
+]
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"]
+}
+
+// CHECK-LABEL: func @tile_from_pointwise_init
+func @tile_from_pointwise_init() {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  %c1 = constant 1 : index
+  // This is the `out` operand that gets through a pointwise generic and becomes
+  // the init operand of linalg.matmul. After tiling and fusion this is a
+  // subtensor of the value produced by linalg.generic.
+  %0 = hal.interface.load.tensor @legacy_io::@TENSOR_INIT, offset = %c0
+    {operand_result_index = 0 : i32} : tensor<2x4xf32>
+  %1 = hal.interface.load.tensor @legacy_io::@TENSOR_LHS, offset = %c0
+    {operand_result_index = 1 : i32} : tensor<2x3xf32>
+  %2 = hal.interface.load.tensor @legacy_io::@TENSOR_RHS, offset = %c0
+    {operand_result_index = 2 : i32} : tensor<3x4xf32>
+
+  // This op is left over and should be DCE'ed but its result is currently used
+  // for a destructive update.
+  %3 = linalg.generic #trait ins(%0 : tensor<2x4xf32>) {
+  ^bb0(%arg0: f32):
+    linalg.yield %arg0 : f32
+  } -> tensor<2x4xf32>
+
+  %4 = iree.workgroup_id {dimension = "x"} : index
+  %5 = iree.workgroup_id {dimension = "y"} : index
+
+  // Test step %{{[0-9a-z]}} { to ensure yield has been folded away.
+  //      CHECK: scf.for %[[I:.*]] = {{.*}} step %{{[0-9a-z]+}} {
+  //      CHECK:   scf.for %[[J:.*]] = {{.*}} step %{{[0-9a-z]+}} {
+  //      CHECK:     %[[LHS:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_LHS, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], 0], sizes = [1, 3], strides = [1, 1] : tensor<1x3xf32>
+  //      CHECK:     %[[RHS:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_RHS, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [0, %[[J]]], sizes = [3, 1], strides = [1, 1] : tensor<3x1xf32>
+  //      CHECK:     %[[OUT:.*]] = hal.interface.load.tensor.tile @legacy_io::@TENSOR_INIT, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], %[[J]]], sizes = [1, 1], strides = [1, 1] : tensor<1x1xf32>
+  // Compute.
+  //      CHECK:     %[[OUT2:.*]] = linalg.generic {{.*}} ins(%[[OUT]] : tensor<1x1xf32>) {
+  //      CHECK:     %[[RES:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<1x3xf32>, tensor<3x1xf32>)
+  // CHECK-SAME:                       init(%[[OUT2]] : tensor<1x1xf32>)
+  // Inplace update (assume only parallel loops have been tiled).
+  //      CHECK:     hal.interface.store.tensor.tile %[[RES]], @legacy_io::@ret0, base_offset = %{{[0-9a-z]*}},
+  // CHECK-SAME:       offsets = [%[[I]], %[[J]]], sizes = [%c1, %c1], strides = [%c1, %c1] : tensor<1x1xf32>
+  %6 = scf.for %arg0 = %5 to %c2 step %c2 iter_args(%arg1 = %3) -> (tensor<2x4xf32>) {
+    %7 = scf.for %arg2 = %4 to %c4 step %c4 iter_args(%arg3 = %arg1) -> (tensor<2x4xf32>) {
+      %8 = subtensor %1[%arg0, 0] [1, 3] [1, 1] : tensor<2x3xf32> to tensor<1x3xf32>
+      %9 = subtensor %2[0, %arg2] [3, 1] [1, 1] : tensor<3x4xf32> to tensor<3x1xf32>
+      %10 = subtensor %0[%arg0, %arg2] [1, 1] [1, 1] : tensor<2x4xf32> to tensor<1x1xf32>
+      %11 = linalg.generic #trait ins(%10 : tensor<1x1xf32>) {
+      ^bb0(%arg4: f32):
+        linalg.yield %arg4 : f32
+      } -> tensor<1x1xf32>
+      %13 = linalg.matmul ins(%8, %9 : tensor<1x3xf32>, tensor<3x1xf32>)
+                         init(%11 : tensor<1x1xf32>) -> tensor<1x1xf32>
+      // This op is the destructive update we seek to eliminate.
+      %14 = subtensor_insert %13 into %arg3[%arg0, %arg2] [%c1, %c1] [%c1, %c1] :
+        tensor<1x1xf32> into tensor<2x4xf32>
+      scf.yield %14 : tensor<2x4xf32>
+    }
+    scf.yield %7 : tensor<2x4xf32>
+  }
+
+  // Under parallel iterators tiling assumptions, simple forwarding occurs.
+  //  CHECK-NOT: hal.interface.load.tensor %
+  //  CHECK-NOT: hal.interface.store.tensor %
+  // This is the store onto which the destructive update rewrites latches.
+  hal.interface.store.tensor %6, @legacy_io::@ret0, offset = %c0
+    {operand_result_index = 3 : i32} : tensor<2x4xf32>
+
+  return
+}
diff --git a/iree/compiler/Conversion/init_conversions.h b/iree/compiler/Conversion/init_conversions.h
index d53c0ee..05bcdc4 100644
--- a/iree/compiler/Conversion/init_conversions.h
+++ b/iree/compiler/Conversion/init_conversions.h
@@ -66,6 +66,7 @@
     createConvImg2ColMatmulConversionPass();
     createLinalgTileAndDistributePass();
     createLinalgTileAndVectorizeWorkgroupsPass();
+    createLinalgRewriteDestructiveUpdatesPass();
     return true;
   }();
   (void)init_once;