| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| //===- DestructiveUpdateUtils.cpp - Utils to rewrite destructive updates -===// |
| // |
| // Implementation of utilities that rewrite destructive updates in Linalg on |
| // tensors IR into in-place updates through memory. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/SCF.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Dominance.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "iree-flow-linalg-rewrite-destructive-updates" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace IREE { |
| namespace Flow { |
| |
| // Detect the pattern: |
| // %d0 = for (iter_args %e0 = %0) |
| // ... |
| // %dk = for ((iter_args %ek = %e{k-1})) |
| // ... |
| // %dn = destructive-update-op (%en) |
| // yield %dn |
| // ... |
| // yield %dk |
| struct SpecialTerminatorOpCapture { |
| Value initValue; |
| // For now, must be scf::ForOps. |
| SmallVector<Operation *, 4> loops; |
| // For now, must be a tensor.insert_slice op or a Linalg/LinalgExt op that |
| // updates one of its output tensors. |
| Operation *rootDestructiveUpdate = nullptr; |
| bool readOnly = false; |
| bool writeOnly = false; |
| }; |
| |
| // TODO(nicolasvasilache): Use some interface instead of op names directly. |
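| // |
| // For illustration, a hypothetical iter_arg with destructive-update uses (all |
| // value names below are made up): the read properly dominates the single |
| // write, so the uses are accepted. |
| // |
| // ``` |
| //   scf.for ... iter_args(%arg = %init) -> (tensor<?x?xf32>) { |
| //     %r = tensor.extract_slice %arg[...] [...] [...]        // read |
| //     %c = linalg.matmul ins(...) outs(%r : tensor<?x?xf32>) -> tensor<?x?xf32> |
| //     %w = tensor.insert_slice %c into %arg[...] [...] [...] // single write |
| //     scf.yield %w : tensor<?x?xf32> |
| //   } |
| // ``` |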
| static bool hasDestructiveUpdateUses(BlockArgument arg, |
| SpecialTerminatorOpCapture &capture) { |
| SmallVector<Operation *> reads; |
| SmallVector<Operation *> writes; |
| for (OpOperand &u : arg.getUses()) { |
| TypeSwitch<Operation *, void>(u.getOwner()) |
| .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp>( |
| [&](auto linalgLikeOp) { |
| if (linalgLikeOp.isOutputTensor(&u)) { |
| writes.push_back(linalgLikeOp); |
| } else { |
| reads.push_back(linalgLikeOp); |
| } |
| }) |
| .Case<tensor::InsertSliceOp>([&](tensor::InsertSliceOp sliceOp) { |
| if (sliceOp.dest() == u.get()) { |
| writes.push_back(sliceOp); |
| } else { |
| reads.push_back(sliceOp); |
| } |
| }) |
| .Default([&](Operation *op) { reads.push_back(op); }); |
| } |
| // For now, only allow exactly a single write (e.g. a tensor.insert_slice) |
| // that must be properly dominated by all the reads. |
| if (writes.size() != 1) return false; |
| // Small local dominance computation. |
| DominanceInfo domInfo(writes.front()->getParentOp()); |
| for (auto read : reads) { |
| LLVM_DEBUG(llvm::dbgs() << "read: " << *read << "\n"); |
| if (!domInfo.properlyDominates(read, writes.front())) { |
| LLVM_DEBUG(llvm::dbgs() << "non-destructive use-def: " << *read |
| << " does not properly dominate " |
| << *(writes.front()) << "\n"); |
| return false; |
| } |
| } |
| |
| capture.readOnly = writes.empty(); |
| capture.writeOnly = reads.empty(); |
| capture.rootDestructiveUpdate = writes.front(); |
| LLVM_DEBUG(llvm::dbgs() << "readOnly: " << capture.readOnly |
| << " writeOnly: " << capture.writeOnly << "\n"); |
| return true; |
| } |
| |
| // Determine whether `tensor` is produced by a destructive update of another |
| // tensor. When successful, fill a SpecialTerminatorOpCapture struct that |
| // captures the relevant (distributed) pieces of IR that form the destructive |
| // update pattern. Iteratively traverse an (imperfectly nested) loop nest such |
| // as: |
| // |
| // ``` |
| // %d0 = for (iter_args %e0 = %0) |
| // ... |
| // %dk = for ((iter_args %ek = %e{k-1})) |
| // ... |
| // %dn = destructive-update-op (%en) |
| // yield %dn |
| // ... |
| // yield %dk |
| // ``` |
| // |
| // to determine whether `d0` is produced by a scf::ForOp with destructive |
| // update semantics. |
| // |
| // Return the value into which the destructive update occurs. |
| // Return nullptr if `tensor` is not a destructive update of some other tensor |
| // value. |
| static Value isADestructiveUpdatePattern(Value tensor, |
| SpecialTerminatorOpCapture &capture) { |
| Value returnValue; |
| while (auto scfForOp = dyn_cast_or_null<scf::ForOp>(tensor.getDefiningOp())) { |
| LLVM_DEBUG(llvm::dbgs() |
| << "Step destructive update pattern: " << scfForOp << "\n"); |
| // Capture the loop. |
| capture.loops.push_back(scfForOp); |
| // Analyze the iterArg at the proper position. |
| unsigned idx = tensor.cast<OpResult>().getResultNumber(); |
| BlockArgument regionArg = *(scfForOp.getRegionIterArgs().begin() + idx); |
| // Set return value if not yet set. |
| if (!returnValue) returnValue = *(scfForOp.getIterOperands().begin() + idx); |
| |
| // Case 1: zero use -> no destructive update. |
| if (regionArg.use_empty()) return nullptr; |
| |
| // Case 2: multiple uses. Then the iterArg may only be used by reads and a |
| // single write (e.g. tensor.extract_slice / tensor.insert_slice) with proper |
| // dominance of the reads over the write. |
| if (!regionArg.hasOneUse()) { |
| if (!hasDestructiveUpdateUses(regionArg, capture)) return nullptr; |
| return returnValue; |
| } |
| |
| assert(regionArg.hasOneUse()); |
| LLVM_DEBUG(llvm::dbgs() << "one use analysis: " << regionArg << "\n"); |
| OpOperand *operand = regionArg.getUses().begin().getOperand(); |
| auto innerForOp = dyn_cast<scf::ForOp>(operand->getOwner()); |
| // Case 3a: single use that is not an scf::ForOp; it may still be a valid |
| // destructive update (e.g. a single tensor.insert_slice write). |
| if (!innerForOp) { |
| if (!hasDestructiveUpdateUses(regionArg, capture)) return nullptr; |
| return returnValue; |
| } |
| |
| // Case 3b: Single use which is an scf::ForOp: `innerIterArgIdx` is the |
| // candidate result and iterArg number. |
| unsigned innerIterArgIdx = |
| operand->getOperandNumber() - innerForOp.getNumControlOperands(); |
| Value innerForOpResultTensor = innerForOp.getResult(innerIterArgIdx); |
| Value yieldValue = |
| scfForOp.region().front().getTerminator()->getOperand(idx); |
| |
| // Check that the return position of dk and the yield position of dk |
| // agree (in the loop structure below). This avoids ping-pong effects |
| // between operands, yields and results. |
| // |
| // %d0 = for (iter_args %e0 = %0) |
| // ... |
| // %dk = for ((iter_args %ek = %e{k-1})) |
| // ... |
| // %dn = destructive-update-op (%en) |
| // yield %dn |
| // ... |
| // yield %dk |
| LLVM_DEBUG(llvm::dbgs() |
| << "innerForOpResultTensor: " << innerForOpResultTensor << "\n" |
| << "yieldValue: " << yieldValue << "\n" |
| << "step in: " << (innerForOpResultTensor == yieldValue) |
| << "\n"); |
| if (innerForOpResultTensor != yieldValue) return nullptr; |
| |
| // Prepare for the next level with the innerForOp's result at position |
| // `innerIterArgIdx`. |
| tensor = innerForOp.getResult(innerIterArgIdx); |
| LLVM_DEBUG(llvm::dbgs() << "next tensor: " << tensor << "\n"); |
| } |
| return nullptr; |
| } |
| |
| /// Convert `%st = tensor.extract_slice %t[offsets][sizes][strides]` to a |
| /// flow.dispatch.tensor.load, when `%t` is (transitively) produced by such a |
| /// load. |
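| /// |
| /// A hypothetical sketch of the rewrite (value names are made up and the op |
| /// syntax is abbreviated); it assumes the original flow.dispatch.tensor.load |
| /// covers the whole dispatch tensor, which is why the slice's |
| /// offsets/sizes/strides can be reused directly below: |
| /// |
| /// ``` |
| ///   %t  = flow.dispatch.tensor.load %binding, ... -> tensor<8x8xf32> |
| ///   %st = tensor.extract_slice %t[%i, 0] [4, 8] [1, 1] |
| ///           : tensor<8x8xf32> to tensor<4x8xf32> |
| /// ``` |
| /// becomes: |
| /// ``` |
| ///   %st = flow.dispatch.tensor.load %binding, offsets = [%i, 0], |
| ///           sizes = [4, 8], strides = [1, 1] -> tensor<4x8xf32> |
| /// ``` |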
| static LogicalResult propagateSubTensorOp(OpBuilder &b, |
| tensor::ExtractSliceOp op) { |
| OpBuilder::InsertionGuard g(b); |
| b.setInsertionPoint(op); |
| auto loadOp = op.source().getDefiningOp<IREE::Flow::DispatchTensorLoadOp>(); |
| if (!loadOp) { |
| BlockArgument val = op.source().dyn_cast<BlockArgument>(); |
| while (val) { |
| auto forOp = dyn_cast<scf::ForOp>(val.getOwner()->getParentOp()); |
| // `val` is a block argument, but not of an scf::ForOp -> bail. |
| if (!forOp) return failure(); |
| unsigned idx = val.getArgNumber() - 1; // accounting for IV arg. |
| Value iterOperand = *(forOp.getIterOperands().begin() + idx); |
| loadOp = iterOperand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>(); |
| val = iterOperand.dyn_cast<BlockArgument>(); |
| } |
| } |
| if (!loadOp) return failure(); |
| |
| Value loaded = b.create<IREE::Flow::DispatchTensorLoadOp>( |
| op.getLoc(), op.getResult().getType(), loadOp.source(), op.offsets(), |
| op.sizes(), op.strides(), op.static_offsets(), op.static_sizes(), |
| op.static_strides()); |
| |
| op.getResult().replaceAllUsesWith(loaded); |
| op.erase(); |
| return success(); |
| } |
| |
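| // Rewrites a destructive update whose root is a Linalg/LinalgExt op directly |
| // yielded by the enclosing loop. A hypothetical sketch (value names are made |
| // up; the store's offsets/sizes/strides are elided, matching the builder call |
| // in the code below): |
| // |
| // ``` |
| //   %r = linalg.matmul ins(...) outs(%ek : tensor<?x?xf32>) -> tensor<?x?xf32> |
| //   scf.yield %r : tensor<?x?xf32> |
| // ``` |
| // becomes, with `target` the destination of the enclosing |
| // flow.dispatch.tensor.store: |
| // ``` |
| //   %r = linalg.matmul ins(...) outs(%ek : tensor<?x?xf32>) -> tensor<?x?xf32> |
| //   flow.dispatch.tensor.store %r, %target, ... |
| //   scf.yield %ek : tensor<?x?xf32> |
| // ``` |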
| template <typename OpTy> |
| static LogicalResult rewriteDestructiveUpdateInPlace(OpBuilder &b, |
| OpTy linalgLikeOp, |
| Value target) { |
| LLVM_DEBUG(llvm::dbgs() << "RewriteDestructiveUpdateInPlace: " |
| << *linalgLikeOp.getOperation() << "\n"); |
| if (!linalgLikeOp->hasOneUse()) { |
| return linalgLikeOp.emitError("not a single use operation"); |
| } |
| |
| OpOperand &use = *(linalgLikeOp->use_begin()); |
| if (isa<scf::YieldOp>(use.getOwner())) { |
| OpResult usedResult = use.get().cast<OpResult>(); |
| Value dest = |
| linalgLikeOp.getOutputOperand(usedResult.getResultNumber())->get(); |
| if (!dest || !dest.isa<BlockArgument>()) { |
| return linalgLikeOp.emitError("dest is not an argument of the enclosing loop"); |
| } |
| OpBuilder::InsertionGuard g(b); |
| b.setInsertionPointAfter(linalgLikeOp); |
| |
| // Kills the SSA use-def chain. |
| usedResult.replaceAllUsesWith(dest); |
| |
| b.create<IREE::Flow::DispatchTensorStoreOp>(linalgLikeOp.getLoc(), |
| usedResult, target); |
| |
| return success(); |
| } |
| return failure(); |
| } |
| |
| /// Rewrites destructive in-place updates with the update operation being |
| /// tensor.insert_slice. |
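| /// |
| /// A hypothetical sketch (value names are made up, op syntax abbreviated): the |
| /// yielded tensor.insert_slice becomes a flow.dispatch.tensor.store of its |
| /// source into `target` with the same offsets/sizes/strides, and the yield |
| /// forwards the loop-carried destination; the now-dead insert_slice is left |
| /// for later canonicalization/DCE to remove. |
| /// |
| /// ``` |
| ///   %w = tensor.insert_slice %c into %ek[%i, 0] [4, 8] [1, 1] |
| ///   scf.yield %w : tensor<?x?xf32> |
| /// ``` |
| /// becomes: |
| /// ``` |
| ///   flow.dispatch.tensor.store %c, %target, offsets = [%i, 0], |
| ///     sizes = [4, 8], strides = [1, 1] |
| ///   scf.yield %ek : tensor<?x?xf32> |
| /// ``` |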
| template <> |
| LogicalResult rewriteDestructiveUpdateInPlace<tensor::InsertSliceOp>( |
| OpBuilder &b, tensor::InsertSliceOp insertSliceOp, Value target) { |
| LLVM_DEBUG(llvm::dbgs() << "RewriteDestructiveUpdateInPlace: " |
| << *insertSliceOp.getOperation() << "\n"); |
| if (!insertSliceOp->hasOneUse()) { |
| return insertSliceOp.emitError("not a single use operation"); |
| } |
| |
| OpOperand &use = *(insertSliceOp->use_begin()); |
| if (isa<scf::YieldOp>(use.getOwner())) { |
| OpResult usedResult = use.get().cast<OpResult>(); |
| Value dest = insertSliceOp.dest(); |
| if (!dest || !dest.isa<BlockArgument>()) { |
| return insertSliceOp.emitError("dest is not an argument of the enclosing loop"); |
| } |
| OpBuilder::InsertionGuard g(b); |
| b.setInsertionPoint(insertSliceOp); |
| |
| // Kills the SSA use-def chain. |
| usedResult.replaceAllUsesWith(dest); |
| |
| b.create<IREE::Flow::DispatchTensorStoreOp>( |
| insertSliceOp->getLoc(), insertSliceOp.source(), target, |
| insertSliceOp.offsets(), insertSliceOp.sizes(), insertSliceOp.strides(), |
| insertSliceOp.static_offsets(), insertSliceOp.static_sizes(), |
| insertSliceOp.static_strides()); |
| |
| return success(); |
| } |
| return failure(); |
| } |
| |
| // Return true if any control flow is found in the DispatchWorkgroupsOp besides |
| // scf::ForOp. |
| static bool hasNonScfForControlFlow( |
| IREE::Flow::DispatchWorkgroupsOp dispatchOp) { |
| return dispatchOp |
| .walk([&](Operation *op) { |
| if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op)) { |
| if (!isa<scf::ForOp>(op) && |
| !isa<IREE::Flow::DispatchWorkgroupsOp>(op)) |
| return WalkResult::interrupt(); |
| } |
| return WalkResult::advance(); |
| }) |
| .wasInterrupted(); |
| } |
| |
| // Rewrite specific tensor.extract_slice / tensor.insert_slice ops that match a |
| // "destructive tensor update" pattern into in-place updates on the `target` |
| // !flow.dispatch.tensor, using flow.dispatch.tensor.load/store ops. |
| // This serves as a step in bridging the abstraction gap between transformed |
| // "linalg on tensors" IR and the buffer world. |
| // This is possible because we control the production of such patterns in IREE |
| // and can take the necessary shortcuts w.r.t. in-place semantics. |
| // In the future it is reasonable to expect special IR constructs to capture |
| // some of the destructive update patterns. |
| // |
| // Assumptions/Invariants on "Control the Production of Such Patterns" |
| // =================================================================== |
| // 1. Input tensors may not participate in a destructive update pattern. |
| // 2. Init and output tensors may participate in a destructive update pattern. |
| // 3. No init or output tensor backing storage aliases with any other tensor |
| // storage. |
| // 4. tensor.extract_slice/tensor.insert_slice are the only ops that can |
| //    extract/insert from/into tensors. |
| // 5. All tensor.extract_slice/tensor.insert_slice ops must have been |
| //    introduced by Linalg tiling on tensors. |
| // 6. Such tilings that result in yielded tensors across loops may only tile |
| //    parallel Linalg iterators at the moment. |
| // 7. (Future) Allow tiling of non-parallel Linalg iterators and ensure |
| //    first-read or write-only semantics by construction. |
| // |
| // Note: the assumptions/invariants above are subject to changing ordering of |
| // passes. When dispatch region and hal.interfaces are created on the linalg on |
| // buffers path, these are all assumptions. In the future, when dispatch regions |
| // and hal.interfaces are created post-transformations on the linalg on tensors |
| // path some assumptions will become invariants. |
| // |
| // For now, the following destructive update patterns are rewritten. |
| // |
| // Coming from an `InterfaceLoadTensorOp` |
| // ====================================== |
| // ``` |
| // %0 = hal.interface.load.tensor @x[offsetx] |
| // ... |
| // %1 = destructive_update(%0) |
| // ... |
| // use_of(%1) // e.g. hal.interface.store.tensor %1 @y[offsety] |
| // ``` |
| // is rewritten into: |
| // ``` |
| // %0 = hal.interface.load.tensor @x[offsetx] |
| // ... |
| // inplace_update @binding[offset] |
| // %2 = hal.interface.load.tensor @binding[offset] |
| // ... |
| // use_of(%2) // e.g. hal.interface.store.tensor %2 @y[offsety] |
| // ``` |
| // |
| // This is a typical pattern that appears after tiling Linalg ops on tensors |
| // with operands that come from hal.interface. |
| // |
| // Coming from a `LinalgOp` |
| // ========================= |
| // ``` |
| // %0 = linalg-op |
| // ... |
| //   %1 = destructive_update(%0) // only tensor.insert_slice ops into %0 |
| // ... |
| // use_of(%1) // e.g. hal.interface.store.tensor %1 @y |
| // ``` |
| // is rewritten into: |
| // ``` |
| // %0 = linalg-op |
| // ... |
| // inplace_update @binding[offset] |
| // %2 = hal.interface.load.tensor @binding[offset] |
| // ... |
| // hal.interface.store.tensor %2 @y[offsety] |
| // ``` |
| // This is a typical pattern that appears after tiling and fusing ops whose |
| // operands are produced by other Linalg ops. In this case, tile-and-fuse |
| // leaves %0 behind |
| // because it is the op that materializes the full tensor. This could be |
| // replaced by a notional "tensor.undef" and the compute would become a dead |
| // value. |
| // The rewrite breaks the use-def chain for %0 and may result in the linalg-op |
| // being DCE'd. |
| // |
| // Other rewrites: |
| // =============== |
| // Furthermore, when `@binding` == `@y` and `offset` == `offsety` and `...` |
| // contains no aliasing read/write to either `@binding[offset]` or `@y[offsety]` |
| // the following: |
| // ``` |
| // %2 = hal.interface.load.tensor @binding[offset] |
| // ... |
| // hal.interface.store.tensor %2 @y[offsety] |
| // ``` |
| // is elided. |
| // This should probably become a dedicated pass based on core alias analysis, |
| // when the latter becomes available. |
| static LogicalResult rewriteDestructiveUpdateInPlace( |
| OpBuilder &b, SpecialTerminatorOpCapture &capture, Value target) { |
| Operation *outermostProducingOp = (capture.loops.empty()) |
| ? capture.rootDestructiveUpdate |
| : capture.loops.front(); |
| LLVM_DEBUG(llvm::dbgs() << "outermost producing: " << *outermostProducingOp |
| << "\n"); |
| |
| // Try to rewrite inplace. |
| auto status = |
| TypeSwitch<Operation *, LogicalResult>(capture.rootDestructiveUpdate) |
| .Case<linalg::LinalgOp, linalg_ext::LinalgExtOp, |
| tensor::InsertSliceOp>([&](auto op) { |
| if (failed(rewriteDestructiveUpdateInPlace(b, op, target))) { |
| return failure(); |
| } |
| return success(); |
| }) |
| .Default([&](Operation *) { return failure(); }); |
| if (failed(status)) return failure(); |
| |
| if (scf::ForOp loopOp = dyn_cast<scf::ForOp>(outermostProducingOp)) |
| loopOp.walk( |
| [&](tensor::ExtractSliceOp op) { (void)propagateSubTensorOp(b, op); }); |
| |
| return success(); |
| } |
| |
| // TODO(nicolasvasilache): generalize beyond the naive "consecutive ops at the |
| // top of the region" check. Probably better to wait until core alias analysis |
| // is upstreamed. |
| // TODO(nicolasvasilache): interfaces. |
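| // |
| // The only shape currently treated as alias-free is a load immediately |
| // followed by the store at the top level of the dispatch region body, e.g. |
| // (hypothetical names, operand details elided): |
| // |
| // ``` |
| //   %0 = flow.dispatch.tensor.load %binding, ... |
| //   flow.dispatch.tensor.store %0, %binding, ... |
| // ``` |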
| static bool hasInterleavedAliases(IREE::Flow::DispatchTensorLoadOp loadOp, |
| IREE::Flow::DispatchTensorStoreOp storeOp) { |
| Block *bLoad = loadOp.getOperation()->getBlock(); |
| Block *bStore = storeOp.getOperation()->getBlock(); |
| if (!isa<IREE::Flow::DispatchWorkgroupsOp>(bLoad->getParentOp()) || |
| !isa<IREE::Flow::DispatchWorkgroupsOp>(bStore->getParentOp()) || |
| bLoad->getParentOp() != bStore->getParentOp()) |
| return true; |
| |
| if (storeOp.getOperation()->getPrevNode() != loadOp) return true; |
| |
| return false; |
| } |
| |
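| // Entry point: for each flow.dispatch.tensor.store in `dispatchOp`, detect a |
| // destructive update pattern feeding it and rewrite it in place. A |
| // hypothetical end-to-end sketch (value names are made up, op syntax |
| // abbreviated): |
| // |
| // ``` |
| //   %init = flow.dispatch.tensor.load %binding, ... -> tensor<8x8xf32> |
| //   %res = scf.for %i = ... iter_args(%e = %init) -> (tensor<8x8xf32>) { |
| //     %s = tensor.extract_slice %e[%i, 0] [1, 8] [1, 1] ... |
| //     // ... compute %c from %s ... |
| //     %w = tensor.insert_slice %c into %e[%i, 0] [1, 8] [1, 1] ... |
| //     scf.yield %w : tensor<8x8xf32> |
| //   } |
| //   flow.dispatch.tensor.store %res, %binding, ... |
| // ``` |
| // is rewritten into: |
| // ``` |
| //   %init = flow.dispatch.tensor.load %binding, ... -> tensor<8x8xf32> |
| //   %res = scf.for %i = ... iter_args(%e = %init) -> (tensor<8x8xf32>) { |
| //     %s = flow.dispatch.tensor.load %binding, offsets = [%i, 0], ... |
| //     // ... compute %c from %s ... |
| //     flow.dispatch.tensor.store %c, %binding, offsets = [%i, 0], ... |
| //     scf.yield %e : tensor<8x8xf32> |
| //   } |
| // ``` |
| // The original trailing store is erased (so %res becomes unused); remaining |
| // dead tensor ops are cleaned up by the canonicalization applied at the end. |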
| LogicalResult rewriteLinalgDestructiveUpdates( |
| IREE::Flow::DispatchWorkgroupsOp dispatchOp) { |
| // Bail on any control-flow for now. |
| if (hasNonScfForControlFlow(dispatchOp)) { |
| return success(); |
| } |
| |
| MLIRContext *context = dispatchOp->getContext(); |
| OpBuilder b(context); |
| SmallVector<IREE::Flow::DispatchTensorStoreOp> processedStores; |
| // For each tensor store op, look for destructive updates and replace the |
| // destructive pattern by a custom inplace update pattern. |
| bool fail = dispatchOp |
| .walk([&](IREE::Flow::DispatchTensorStoreOp op) { |
| SpecialTerminatorOpCapture capture; |
| capture.initValue = op.value(); |
| Value sourceValue = |
| isADestructiveUpdatePattern(capture.initValue, capture); |
| if (!sourceValue) { |
| return WalkResult::advance(); |
| } |
| if (failed(rewriteDestructiveUpdateInPlace(b, capture, |
| op.target()))) { |
| return WalkResult::interrupt(); |
| } |
| processedStores.push_back(op); |
| return WalkResult::advance(); |
| }) |
| .wasInterrupted(); |
| if (fail) return failure(); |
| |
| for (auto op : processedStores) { |
| op.erase(); |
| } |
| |
| // Non-default canonicalization patterns. |
| // TODO(nicolasvasilache): add Linalg tiling canonicalization patterns, |
| // affineminscf and others as needed. |
| OwningRewritePatternList canonicalizationPatterns(context); |
| scf::ForOp::getCanonicalizationPatterns(canonicalizationPatterns, context); |
| (void)applyPatternsAndFoldGreedily(dispatchOp, |
| std::move(canonicalizationPatterns)); |
| return success(); |
| } |
| |
| } // namespace Flow |
| } // namespace IREE |
| } // namespace iree_compiler |
| } // namespace mlir |