Adding -iree-stream-materialize-copy-on-write pass. (#7527)
This pass, along with the associated `-iree-stream-elide-async-copies` cleanup pass, implements the copy-on-write logic in the stream dialect. The first pass conservatively inserts copies of tensors that are mutated while they still have other uses, and the cleanup pass performs a whole-program analysis to elide the copies that turn out not to be required.
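
As a rough before/after sketch (adapted from the `@multiUseTiedOperand` test cases below; value names and the exact printed forms are illustrative):

```mlir
// Input: the splat is mutated in-place by two tied fills but still has other uses.
%0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
%1 = stream.async.fill %c456_i32, %0[%c0 to %c128 for %c128] : i32 -> %0 as !stream.resource<*>{%size}
%2 = stream.async.fill %c789_i32, %0[%c128 to %c256 for %c128] : i32 -> %0 as !stream.resource<*>{%size}

// After -iree-stream-materialize-copy-on-write: each mutation clones its tied operand.
%0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
%clone0 = stream.async.clone %0 : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
%1 = stream.async.fill %c456_i32, %clone0[%c0 to %c128 for %c128] : i32 -> %clone0 as !stream.resource<*>{%size}
%clone1 = stream.async.clone %0 : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
%2 = stream.async.fill %c789_i32, %clone1[%c128 to %c256 for %c128] : i32 -> %clone1 as !stream.resource<*>{%size}

// After -iree-stream-elide-async-copies: the second fill is the last use of %0 so it
// can mutate the splat in-place and its clone is elided; the first clone remains.
%0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
%clone0 = stream.async.clone %0 : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
%1 = stream.async.fill %c456_i32, %clone0[%c0 to %c128 for %c128] : i32 -> %clone0 as !stream.resource<*>{%size}
%2 = stream.async.fill %c789_i32, %0[%c128 to %c256 for %c128] : i32 -> %0 as !stream.resource<*>{%size}
```
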
diff --git a/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 5a3906f..473df80 100644
--- a/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -646,7 +646,6 @@
void TensorImportOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
// TODO(benvanik): check operand and dedupe imports.
- results.insert<MaterializeCOW<TensorImportOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -836,7 +835,6 @@
// TODO(benvanik): alloca (staging) -> non-staging change to target.
// TODO(benvanik): alloca (non-staging) -> staging change to target.
// TODO(benvanik): sink to first user.
- results.insert<MaterializeCOW<AsyncAllocaOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -847,7 +845,6 @@
OwningRewritePatternList &results, MLIRContext *context) {
// TODO(benvanik): if value is a splat turn into splat.
// TODO(benvanik): if value is _mostly_ a splat, turn into splat + updates.
- results.insert<MaterializeCOW<AsyncConstantOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -923,7 +920,6 @@
// TODO(#6972): clone instead of sinking to common dominator.
results.insert<SinkSplatsToConsumers>(context);
results.insert<ElideUnusedOp<AsyncSplatOp>>(context);
- results.insert<MaterializeCOW<AsyncSplatOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -973,7 +969,6 @@
// TODO(benvanik): some way to reduce deep clone->clone->clone chains.
results.insert<PropagateClonableOps>(context);
results.insert<ElideUnusedOp<AsyncCloneOp>>(context);
- results.insert<MaterializeCOW<AsyncCloneOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1020,7 +1015,6 @@
// affinity/lifetime differ.
results.insert<PropagateSplatsThroughSlices>(context);
results.insert<ElideUnusedOp<AsyncSliceOp>>(context);
- results.insert<MaterializeCOW<AsyncSliceOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1057,7 +1051,6 @@
MLIRContext *context) {
results.insert<FlattenFullFillToSplat>(context);
results.insert<ElideUnusedOp<AsyncFillOp>>(context);
- results.insert<MaterializeCOW<AsyncFillOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1147,7 +1140,6 @@
results.insert<CombineSplatUpdateFromToFill>(context);
results.insert<CombineSliceUpdateFromToCopy>(context);
results.insert<ElideUnusedOp<AsyncUpdateOp>>(context);
- results.insert<MaterializeCOW<AsyncUpdateOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1185,7 +1177,6 @@
MLIRContext *context) {
results.insert<AsyncCopyFullSourceToUpdate>(context);
results.insert<ElideUnusedOp<AsyncCopyOp>>(context);
- results.insert<MaterializeCOW<AsyncCopyOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1227,7 +1218,6 @@
// TODO(benvanik): staging propagation (fill of staging -> fill on device).
results.insert<RedundantTransferElision>(context);
results.insert<ElideUnusedOp<AsyncTransferOp>>(context);
- results.insert<MaterializeCOW<AsyncTransferOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1261,7 +1251,6 @@
OwningRewritePatternList &results, MLIRContext *context) {
// TODO(benvanik): nothing? maybe tied type/lifetime updates?
results.insert<ElideUnusedOp<AsyncDispatchOp>>(context);
- results.insert<MaterializeCOW<AsyncDispatchOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1452,7 +1441,6 @@
context);
results.insert<TieRegionResults<AsyncExecuteOp>>(context);
results.insert<ElideUnusedOp<AsyncExecuteOp>>(context);
- results.insert<MaterializeCOW<AsyncExecuteOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -1465,7 +1453,6 @@
context);
results.insert<TieRegionResults<AsyncConcurrentOp>>(context);
results.insert<ElideUnusedOp<AsyncConcurrentOp>>(context);
- results.insert<MaterializeCOW<AsyncConcurrentOp>>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Stream/Transforms/BUILD b/iree/compiler/Dialect/Stream/Transforms/BUILD
index 9e06012..3325787 100644
--- a/iree/compiler/Dialect/Stream/Transforms/BUILD
+++ b/iree/compiler/Dialect/Stream/Transforms/BUILD
@@ -16,7 +16,9 @@
name = "Transforms",
srcs = [
"ConvertToStream.cpp",
+ "ElideAsyncCopies.cpp",
"EncodeTensors.cpp",
+ "MaterializeCopyOnWrite.cpp",
"OutlineConstants.cpp",
"PassDetail.h",
"Passes.cpp",
@@ -36,6 +38,8 @@
"//iree/compiler/Dialect/Stream/Conversion/StandardToStream",
"//iree/compiler/Dialect/Stream/Conversion/UtilToStream",
"//iree/compiler/Dialect/Stream/IR",
+ "//iree/compiler/Dialect/Util/Analysis",
+ "//iree/compiler/Dialect/Util/Analysis/DFX",
"//iree/compiler/Dialect/Util/Conversion",
"//iree/compiler/Dialect/Util/IR",
"//iree/compiler/Dialect/Util/Transforms",
diff --git a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 82f8335..14ebef4 100644
--- a/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -18,7 +18,9 @@
"Passes.h.inc"
SRCS
"ConvertToStream.cpp"
+ "ElideAsyncCopies.cpp"
"EncodeTensors.cpp"
+ "MaterializeCopyOnWrite.cpp"
"OutlineConstants.cpp"
"PassDetail.h"
"Passes.cpp"
@@ -45,6 +47,8 @@
iree::compiler::Dialect::Stream::Conversion::StandardToStream
iree::compiler::Dialect::Stream::Conversion::UtilToStream
iree::compiler::Dialect::Stream::IR
+ iree::compiler::Dialect::Util::Analysis
+ iree::compiler::Dialect::Util::Analysis::DFX
iree::compiler::Dialect::Util::Conversion
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
diff --git a/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp b/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
new file mode 100644
index 0000000..27e09e1
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/ElideAsyncCopies.cpp
@@ -0,0 +1,531 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h"
+#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-stream-elide-async-copies"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Resource usage query/application patterns
+//===----------------------------------------------------------------------===//
+
+// TODO(benvanik): change this to just be an AbstractState - there's no real
+// need for PVS as we don't track dynamically and are just using this as a
+// cache.
+class LastUsers
+ : public DFX::StateWrapper<DFX::PotentialValuesState<Operation *>,
+ DFX::ValueElement> {
+ public:
+ using BaseType = DFX::StateWrapper<DFX::PotentialValuesState<Operation *>,
+ DFX::ValueElement>;
+
+ static LastUsers &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) LastUsers(pos));
+ }
+
+ const std::string getName() const override { return "LastUsers"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ // Returns true if the given |op| is known to be a last user of the value.
+ // Note that a single op may use a value multiple times.
+ bool isAssumedLastUser(Operation *op) const {
+ return getAssumedSet().contains(op);
+ }
+
+ const std::string getAsStr() const override {
+ return std::string("last users: ") + std::to_string(getAssumedSet().size());
+ }
+
+ private:
+ explicit LastUsers(const Position &pos) : BaseType(pos) {}
+
+ void initializeValue(Value value, DFX::Solver &solver) override {
+ // NOTE: this is only for the local region; we don't touch transitive users.
+ // TODO(benvanik): touch transitive users? We could evaluate with
+ // solver.getExplorer().walkTransitiveUsers() and ensure all tied uses
+ // go out of scope at the right time. For now we assume that the SSA
+ // value last users are all we care about.
+ auto parentOp =
+ value.getParentRegion()->getParentOfType<mlir::CallableOpInterface>();
+ auto liveness = solver.getExplorer()
+ .getAnalysisManager()
+ .nest(parentOp)
+ .getAnalysis<Liveness>();
+ for (auto user : value.getUsers()) {
+ if (liveness.isDeadAfter(value, user)) {
+ unionAssumed(user);
+ }
+ }
+ indicateOptimisticFixpoint();
+
+ LLVM_DEBUG({
+ AsmState asmState(value.getParentBlock()->getParentOp());
+ llvm::dbgs() << "[elide-copies] initialized value last users for ";
+ value.printAsOperand(llvm::dbgs(), asmState);
+ llvm::dbgs() << ": " << getAssumedSet().size() << "\n";
+ for (auto user : getAssumedSet()) {
+ llvm::dbgs() << " ";
+ user->print(llvm::dbgs(), OpPrintingFlags().elideLargeElementsAttrs());
+ llvm::dbgs() << "\n";
+ }
+ });
+ }
+
+ ChangeStatus updateValue(Value value, DFX::Solver &solver) override {
+ // NOTE: this is purely a cache and is based only on the initial value;
+ // this should never be called.
+ return ChangeStatus::UNCHANGED;
+ }
+
+ friend class DFX::Solver;
+};
+const char LastUsers::ID = 0;
+
+class ArgumentSemantics
+ : public DFX::StateWrapper<DFX::BitIntegerState<uint8_t, 3, 0>,
+ DFX::ValueElement> {
+ public:
+ using BaseType =
+ DFX::StateWrapper<DFX::BitIntegerState<uint8_t, 3, 0>, DFX::ValueElement>;
+
+ // Inverted bits so that we can go from best (all bits set) to worst (no bits
+ // set).
+ enum {
+ // Argument is _not_ mutated within the region it is used.
+ NOT_MUTATED = 1u << 0,
+ // Argument is _not_ by reference (so: by value). Indicates that the
+ // argument is not retained at any predecessor/caller and is owned by the
+ // receiver.
+ NOT_BY_REFERENCE = 1u << 1,
+
+ BEST_STATE = NOT_MUTATED | NOT_BY_REFERENCE,
+ };
+ static_assert(BEST_STATE == BaseType::getBestState(),
+ "unexpected BEST_STATE value");
+
+ static ArgumentSemantics &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) ArgumentSemantics(pos));
+ }
+
+ const std::string getName() const override { return "ArgumentSemantics"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ // Returns true if the argument is known to be passed by-value from all
+ // predecessors/callers.
+ bool getKnownByValue() const {
+ return (this->getKnown() & NOT_BY_REFERENCE) == NOT_BY_REFERENCE;
+ }
+
+ // Returns true if the argument is assumed to be passed by-value from all
+ // predecessors/callers.
+ bool getAssumedByValue() const {
+ return (this->getAssumed() & NOT_BY_REFERENCE) == NOT_BY_REFERENCE;
+ }
+
+ const std::string getAsStr() const override {
+ std::string str;
+ auto append = [&](const char *part) {
+ if (!str.empty()) str += '|';
+ str += part;
+ };
+ append(this->isAssumed(NOT_MUTATED) ? "immutable" : "mutable");
+ append(this->isAssumed(NOT_BY_REFERENCE) ? "by-value" : "by-reference");
+ return str.empty() ? "*" : str;
+ }
+
+ private:
+ explicit ArgumentSemantics(const Position &pos) : BaseType(pos) {}
+
+ // Returns true if |operand| is tied to a result on its owner indicating an
+ // in-place operation.
+ static bool isTiedUse(OpOperand &operand) {
+ if (auto tiedOp =
+ dyn_cast<IREE::Util::TiedOpInterface>(operand.getOwner())) {
+ if (tiedOp.isOperandTied(operand.getOperandNumber())) return true;
+ }
+ return false;
+ }
+
+ // Starts analysis of the |value| with known bits based on IR structure.
+ void initializeValue(Value value, DFX::Solver &solver) override {
+ // Start as NOT_MUTATED and NOT_BY_REFERENCE (by-value).
+ intersectAssumedBits(BEST_STATE);
+
+ // If any use is tied then we know we are mutated in-place.
+ // Note that this walks into call targets and across branches.
+ auto traversalResult = solver.getExplorer().walkTransitiveUses(
+ value, [&](OpOperand &operand) -> WalkResult {
+ if (isTiedUse(operand)) {
+ // Mutated in-place; nothing more we need to do.
+ removeKnownBits(NOT_MUTATED);
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Analysis incomplete - mark as conservatively by reference/mutated.
+ removeKnownBits(NOT_MUTATED | NOT_BY_REFERENCE);
+ }
+ }
+
+ // Updates the element state based on _a_ predecessor operand that is the
+  // source of the argument value. Will be called once per predecessor/caller.
+ void updateFromPredecessorUse(OpOperand &operand, DFX::Solver &solver) {
+ // If the operand is a block argument then we need to ask for the argument
+ // semantics first - if it's by reference then it's definitely not the last
+ // use and we can short-circuit this.
+ if (auto arg = operand.get().dyn_cast<BlockArgument>()) {
+ auto &argumentSemantics = solver.getElementFor<ArgumentSemantics>(
+ *this, Position::forValue(operand.get()), DFX::Resolution::REQUIRED);
+ LLVM_DEBUG(llvm::dbgs() << " pred is arg; combining state: "
+ << argumentSemantics.getAsStr() << "\n");
+ getState() ^= argumentSemantics.getState();
+ }
+
+ auto &lastUsers = solver.getElementFor<LastUsers>(
+ *this, Position::forValue(operand.get()), DFX::Resolution::REQUIRED);
+ bool isLastUser = lastUsers.isAssumedLastUser(operand.getOwner());
+ if (!isLastUser) {
+ // Not the last user - value is passed in by reference.
+ LLVM_DEBUG(llvm::dbgs() << " not the last user\n");
+ removeAssumedBits(NOT_BY_REFERENCE | NOT_MUTATED);
+ }
+ }
+
+ // Updates the semantics of |value| by walking all predecessors/callers (up
+ // through function arguments, branch arguments, and tied results) and all
+ // transitive uses (down through function calls, branches, and tied operands)
+ // by way of usage analysis.
+ ChangeStatus updateValue(Value value, DFX::Solver &solver) override {
+ auto assumedBits = getAssumed();
+ auto traversalResult = TraversalResult::COMPLETE;
+
+ auto arg = value.cast<BlockArgument>();
+ bool isEntryArg = arg.getParentBlock()->isEntryBlock();
+ if (isEntryArg) {
+ // Call argument.
+ auto callableOp =
+ cast<mlir::CallableOpInterface>(arg.getParentBlock()->getParentOp());
+ traversalResult |= solver.getExplorer().walkIncomingCalls(
+ callableOp, [&](mlir::CallOpInterface callOp) -> WalkResult {
+ unsigned baseIdx = callOp.getArgOperands().getBeginOperandIndex();
+ auto &sourceOperand =
+ callOp->getOpOperand(baseIdx + arg.getArgNumber());
+ updateFromPredecessorUse(sourceOperand, solver);
+ return WalkResult::advance();
+ });
+ } else {
+ // Branch argument.
+ traversalResult |= solver.getExplorer().walkIncomingBranchOperands(
+ arg.getParentBlock(),
+ [&](Block *sourceBlock, OperandRange operands) -> WalkResult {
+ unsigned baseIdx = operands.getBeginOperandIndex();
+ auto &sourceOperand = sourceBlock->getTerminator()->getOpOperand(
+ baseIdx + arg.getArgNumber());
+ updateFromPredecessorUse(sourceOperand, solver);
+ return WalkResult::advance();
+ });
+ }
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << " !! traversal result incomplete; assuming by reference\n");
+ removeAssumedBits(NOT_BY_REFERENCE | NOT_MUTATED);
+ }
+ return assumedBits == getAssumed() ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+ }
+
+ friend class DFX::Solver;
+};
+const char ArgumentSemantics::ID = 0;
+
+// TODO(benvanik): change into something we can use for ref counting. We need
+// that to insert stream-ordered deallocs and know when timepoints have been
+// discarded as they go out of scope. For now this strictly checks last use.
+class LastUseAnalysis {
+ public:
+ explicit LastUseAnalysis(Operation *rootOp)
+ : explorer(rootOp, TraversalAction::SHALLOW),
+ solver(explorer, allocator) {
+ explorer.setOpAction<IREE::Util::InitializerOp>(TraversalAction::RECURSE);
+ explorer.setOpAction<mlir::FuncOp>(TraversalAction::RECURSE);
+ explorer.setDialectAction<IREE::Stream::StreamDialect>(
+ TraversalAction::RECURSE);
+ // Ignore the contents of executables (linalg goo, etc).
+ explorer.setOpAction<IREE::Stream::ExecutableOp>(TraversalAction::IGNORE);
+ explorer.initialize();
+
+ assert(rootOp->getNumRegions() == 1 && "expected module-like root op");
+ topLevelOps = llvm::to_vector<4>(
+ rootOp->getRegions().front().getOps<mlir::CallableOpInterface>());
+ }
+
+ // Runs analysis and populates the state cache.
+ // May fail if analysis cannot be completed due to unsupported or unknown IR.
+ LogicalResult run() {
+ // Seed all block arguments throughout the program.
+ for (auto callableOp : getTopLevelOps()) {
+ for (auto &block : *callableOp.getCallableRegion()) {
+ for (auto arg : block.getArguments()) {
+ if (arg.getType().isa<IREE::Stream::ResourceType>()) {
+ solver.getOrCreateElementFor<ArgumentSemantics>(
+ Position::forValue(arg));
+ }
+ }
+ }
+ }
+
+ // Run solver to completion.
+ return solver.run();
+ }
+
+ // Returns a list of all top-level callable ops in the root op.
+ ArrayRef<mlir::CallableOpInterface> getTopLevelOps() const {
+ return topLevelOps;
+ }
+
+ // Returns true if block argument |arg| is passed in by-value/move (it's the
+ // last use from all callers/predecessor branches). When false the value
+ // represented by the argument may have other uses outside of its block.
+ bool isArgMoved(BlockArgument arg) {
+ auto argumentSemantics =
+ solver.lookupElementFor<ArgumentSemantics>(Position::forValue(arg));
+ if (!argumentSemantics) return false;
+ return argumentSemantics->getAssumedByValue();
+ }
+
+ // Returns true if |userOp| is the last user of |operand|.
+ bool isLastUser(Value operand, Operation *userOp) {
+ auto lastUsers =
+ solver.getOrCreateElementFor<LastUsers>(Position::forValue(operand));
+ return lastUsers.isAssumedLastUser(userOp);
+ }
+
+ private:
+ Explorer explorer;
+ llvm::BumpPtrAllocator allocator;
+ DFX::Solver solver;
+ SmallVector<mlir::CallableOpInterface> topLevelOps;
+};
+
+// Returns true if the given clone op performs no meaningful work and can be
+// elided by rerouting uses of its result to its source. This is a conservative
+// check and will return false ("not safe to elide") whenever the analysis
+// cannot prove that mutating the source in-place is safe.
+//
+// No-op clone is elidable:
+// %0 ---> %1 = clone(%0) ---> use(%1) // last use of %0
+//
+// Clone required for correctness:
+// %0 ---> %1 = clone(%0) ---> use(%1)
+// \--> use(%0)
+//
+// Second clone elidable, first required:
+// %0 ---> %1 = clone(%0) ---> use(%1)
+// \--> %2 = clone(%0) ---> use(%2) // last use of %0
+static bool isSafeToElideCloneOp(IREE::Stream::AsyncCloneOp cloneOp,
+ LastUseAnalysis &analysis) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "isSafeToElideCloneOp:\n";
+ llvm::dbgs() << " ";
+ cloneOp.print(llvm::dbgs(), OpPrintingFlags().elideLargeElementsAttrs());
+ llvm::dbgs() << "\n";
+ });
+
+ // If this clone is performing a type change we need to preserve it.
+ // TODO(benvanik): remove this carveout - could make clone not change type
+ // and transfer be needed instead.
+ auto sourceType =
+ cloneOp.source().getType().cast<IREE::Stream::ResourceType>();
+ auto targetType =
+ cloneOp.result().getType().cast<IREE::Stream::ResourceType>();
+ if (sourceType != targetType &&
+ sourceType.getLifetime() == IREE::Stream::Lifetime::Constant) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " + clone source is a constant; cannot elide\n");
+ return false;
+ }
+
+ // If the source is a block argument we have to look into the analysis cache
+ // to see if it's been classified as a last use/by-value move. If it isn't
+ // then we cannot mutate it in-place as it could be used by the caller/another
+ // branch and we need to respect the forking of the value.
+ if (auto arg = cloneOp.source().dyn_cast<BlockArgument>()) {
+ if (!analysis.isArgMoved(arg)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " - clone source is a by-ref arg; cannot elide\n");
+ return false;
+ }
+ LLVM_DEBUG(llvm::dbgs()
+ << " ? clone source is a by-value arg; may elide\n");
+ }
+
+ // If there's only one user of the source we know it's this clone and can
+ // bypass all the more expensive liveness analysis.
+ if (cloneOp.source().hasOneUse()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << " + clone source SSA value has one use; can elide\n");
+ return true;
+ }
+
+ // If this is the last user of the source SSA value then we can elide the
+ // clone knowing that any mutations won't impact the source.
+ if (analysis.isLastUser(cloneOp.source(), cloneOp)) {
+ LLVM_DEBUG(llvm::dbgs() << " + clone source use is the last; can elide\n");
+ return true;
+ }
+
+ // Not safe.
+ LLVM_DEBUG(llvm::dbgs() << " - clone source cannot be elided\n");
+ return false;
+}
+
+// Tries to elide |cloneOp| by replacing all uses with its source if safe.
+// Returns true if the op was elided.
+static bool tryElideCloneOp(IREE::Stream::AsyncCloneOp cloneOp,
+ LastUseAnalysis &analysis) {
+ if (!isSafeToElideCloneOp(cloneOp, analysis)) return false;
+ cloneOp.replaceAllUsesWith(cloneOp.source());
+ cloneOp.erase();
+ return true;
+}
+
+// Tries to elide copies nested within |region| when safe.
+// Returns true if any ops were elided.
+static bool tryElideAsyncCopiesInRegion(Region ®ion,
+ LastUseAnalysis &analysis) {
+ bool didChange = false;
+ for (auto &block : region) {
+ for (auto cloneOp : llvm::make_early_inc_range(
+ block.getOps<IREE::Stream::AsyncCloneOp>())) {
+      didChange = tryElideCloneOp(cloneOp, analysis) || didChange;
+ }
+ }
+ return didChange;
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-elide-async-copies
+//===----------------------------------------------------------------------===//
+
+// Elides async copies that perform no meaningful work - such as clones of the
+// last use of a value. This is designed to be run after
+// -iree-stream-materialize-copy-on-write to clean up the copies it introduces
+// but will also pick up any copies that came from the frontend.
+//
+// This should never remove copies that are required for correctness: we err on
+// the side of leaving copies when we cannot perform full analysis.
+//
+// This operates using a whole-program data flow analysis to first determine
+// which block arguments have move semantics (they are passed the last use of
+// a resource) and the last users of all cloned values. Once analyzed all copies
+// in the program are checked to see if they can be safely removed and if so are
+// rerouted to the cloned source value. This process repeats until no more
+// copies are elided: we are guaranteed to reach a fixed point as we are only
+// removing copies in this pass and not introducing any new ops.
+class ElideAsyncCopiesPass : public ElideAsyncCopiesBase<ElideAsyncCopiesPass> {
+ public:
+ ElideAsyncCopiesPass() = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<IREE::Stream::StreamDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ }
+
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ if (moduleOp.getBody()->empty()) return;
+
+ // Try analyzing the program and eliding the unneeded copies until we reach
+ // a fixed point (no more copies can be elided).
+ unsigned maxIterationCount = 30;
+ unsigned iterationCount = 0;
+ for (; iterationCount < maxIterationCount; ++iterationCount) {
+ // Perform whole-program analysis.
+ // TODO(benvanik): reuse allocator across iterations.
+ LastUseAnalysis analysis(moduleOp);
+ if (failed(analysis.run())) {
+ moduleOp.emitError() << "failed to solve for last users";
+ return signalPassFailure();
+ }
+
+ // Apply analysis by eliding all copies that are safe to elide.
+ // If we can't elide any we'll consider the iteration complete and exit.
+ bool didChange = false;
+ for (auto callableOp : analysis.getTopLevelOps()) {
+ didChange = tryElideAsyncCopiesInRegion(*callableOp.getCallableRegion(),
+ analysis) ||
+ didChange;
+ }
+ if (!didChange) break;
+ }
+ if (iterationCount == maxIterationCount) {
+ // If you find yourself hitting this we can evaluate increasing the
+ // iteration count (if it would eventually converge) or whether we allow
+      // this to happen without remarking. For now all our programs converge in
+ // just one or two iterations and this needs to be tuned with more complex
+ // control flow.
+ moduleOp.emitRemark()
+ << "copy elision pass failed to reach a fixed point after "
+ << maxIterationCount << " iterations; unneeded copies may be present";
+ return;
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createElideAsyncCopiesPass() {
+ return std::make_unique<ElideAsyncCopiesPass>();
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp b/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp
new file mode 100644
index 0000000..b5fc0ea
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/MaterializeCopyOnWrite.cpp
@@ -0,0 +1,186 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "iree-stream-materialize-copy-on-write"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Stream {
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Copy-on-write (🐄)
+//===----------------------------------------------------------------------===//
+
+// Returns true if the given |operand| value does not need a copy on write.
+// This is a conservative check and will return false ("not safe to elide") in
+// many cases that otherwise don't need a copy. The
+// -iree-stream-elide-async-copies pass will do a whole-program analysis and
+// remove the copies we insert here when possible.
+static bool isSafeToElideCOW(Value operand, IREE::Stream::ResourceType type) {
+ // Can't do anything with block args without analysis - we don't know if the
+ // value they carry is the last user (move semantics).
+ if (operand.isa<BlockArgument>()) return false;
+
+ // If our value is a constant then we need to ensure that we aren't
+ // tied to a constant operand. If we are we need to clone to a
+ // non-constant value. We could make this work in cases where constants are
+ // being initialized, however those are best modeled as transfer operations
+ // where no mutations will occur on the constant transfer target.
+ if (type.getLifetime() == IREE::Stream::Lifetime::Constant) return false;
+
+ // If there's more than one user we can't make a local decision. It's
+ // expensive to query relative operation order within a block and within a
+ // region the lifetime of values may vary - all things we can't tell here.
+ if (!operand.hasOneUse()) return false;
+
+ // We are the only user and the value is contained entirely within the
+ // current region. We by construction know we do not need to worry.
+ return true;
+}
+
+// Materializes a copy for a mutated |operand| on |affinity| if required.
+// If it's determined that eliding the copy is safe, it will be omitted.
+// Returns true if the copy was required and materialized.
+static bool materializeOperandCOW(Location loc, OpOperand &operand,
+ IREE::Stream::AffinityAttr affinity,
+ OpBuilder &builder) {
+ // If we can safely elide the copy early we do so here to avoid adding too
+ // much IR. Anything that requires wider analysis (CFG, across functions, etc)
+ // has to wait until a subsequent pass.
+ auto resourceType =
+ operand.get().getType().dyn_cast<IREE::Stream::ResourceType>();
+ if (!resourceType) return false;
+ if (isSafeToElideCOW(operand.get(), resourceType)) return false;
+
+ // Materialize a clone operation just for the operand provided.
+ auto sizeAwareType = resourceType.cast<IREE::Util::SizeAwareTypeInterface>();
+ auto size = sizeAwareType.queryValueSize(loc, operand.get(), builder);
+ auto cloneOp = builder.create<IREE::Stream::AsyncCloneOp>(
+ loc, resourceType, operand.get(), size, size, affinity);
+ operand.set(cloneOp.result());
+ return true;
+}
+
+// Materializes a copy for each mutated operand on |tiedOp| as required.
+// Returns true if any copy was required and materialized.
+static bool materializeTiedOpCOW(IREE::Util::TiedOpInterface tiedOp) {
+ bool didChange = false;
+
+ // Any ops we materialize must have the same affinity as their consumer. This
+ // ensures the copies we issue happen locally to the consumer.
+ IREE::Stream::AffinityAttr affinity;
+ if (auto affinityOp =
+ dyn_cast<IREE::Stream::AffinityOpInterface>(tiedOp.getOperation())) {
+ affinity = affinityOp.getAffinity();
+ }
+
+  // Clone each operand that is tied to a result if a copy may be required.
+ OpBuilder builder(tiedOp);
+ unsigned tiedOperandsOffset = tiedOp.getTiedOperandsIndexAndLength().first;
+ auto tiedOperandIndices = tiedOp.getTiedResultOperandIndices();
+ for (unsigned i = 0; i < tiedOperandIndices.size(); ++i) {
+ int64_t operandIdx = tiedOperandIndices[i];
+ if (operandIdx == IREE::Util::TiedOpInterface::kUntiedIndex) continue;
+ auto &operand = tiedOp->getOpOperand(tiedOperandsOffset + operandIdx);
+ didChange =
+ materializeOperandCOW(tiedOp.getLoc(), operand, affinity, builder) ||
+ didChange;
+ }
+
+ return didChange;
+}
+
+// Materializes copies on writes within |region|.
+// Returns true if any copy was required and materialized.
+static bool materializeRegionCOW(Region ®ion) {
+ bool didChange = false;
+ for (auto &block : region.getBlocks()) {
+ for (auto &op : block) {
+ if (!op.hasTrait<OpTrait::IREE::Stream::AsyncPhaseOp>()) continue;
+ didChange =
+ TypeSwitch<Operation *, bool>(&op)
+ .Case<IREE::Stream::TensorImportOp, IREE::Stream::TensorExportOp,
+ IREE::Stream::AsyncFillOp, IREE::Stream::AsyncUpdateOp,
+ IREE::Stream::AsyncCopyOp, IREE::Stream::AsyncDispatchOp,
+ IREE::Stream::AsyncExecuteOp,
+ IREE::Stream::AsyncConcurrentOp>(
+ [&](auto op) { return materializeTiedOpCOW(op); })
+ .Default(false) ||
+ didChange;
+ }
+ }
+ return didChange;
+}
+
+//===----------------------------------------------------------------------===//
+// -iree-stream-materialize-copy-on-write
+//===----------------------------------------------------------------------===//
+
+// Applies a relatively simple heuristic to insert copies where they _may_ be
+// required. This may introduce copies that are not required for the sake of
+// ensuring correctness. Intended to be paired with
+// -iree-stream-elide-async-copies.
+//
+// Conceptually this work is performed in two phases: copy insertion and copy
+// elision. This pass inserts copies at all mutation sites regardless of whether
+// they are required, effectively disabling ties as a mechanism for in-place
+// updates but ensuring correct execution semantics. Afterward a dataflow
+// analysis pass is run to identify which copies can be elided based on use-def
+// chains (including ones spanning the CFG). Though this process can lead to
+// additional copies, it keeps each pass independently verifiable and makes it
+// easy to disable copy elision when ferreting out issues.
+class MaterializeCopyOnWritePass
+ : public MaterializeCopyOnWriteBase<MaterializeCopyOnWritePass> {
+ public:
+ MaterializeCopyOnWritePass() = default;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<mlir::StandardOpsDialect>();
+ registry.insert<mlir::arith::ArithmeticDialect>();
+ registry.insert<IREE::Stream::StreamDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ }
+
+ void runOnOperation() override {
+ bool didChange = false;
+ for (auto ®ion : getOperation()->getRegions()) {
+ didChange = materializeRegionCOW(region) || didChange;
+ }
+ // TODO(benvanik): run canonicalization patterns inline if anything changed.
+ (void)didChange;
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<>> createMaterializeCopyOnWritePass() {
+ return std::make_unique<MaterializeCopyOnWritePass>();
+}
+
+} // namespace Stream
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index 9f771a0..9f648e5 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -113,6 +113,17 @@
passManager.addNestedPass<mlir::FuncOp>(
IREE::Stream::createEncodeTensorsPass());
addCleanupPatterns(passManager);
+
+ // This will insert a lot of copies, so follow it up with a pass that elides
+ // ones that aren't needed. This is easier to verify than if there was one
+ // pass attempting to do both. Note that copy-on-write materialization is
+ // required for correct execution while copy elision is for performance only
+ // (though it's critical enough that it is not optional).
+ passManager.addNestedPass<IREE::Util::InitializerOp>(
+ IREE::Stream::createMaterializeCopyOnWritePass());
+ passManager.addNestedPass<mlir::FuncOp>(
+ IREE::Stream::createMaterializeCopyOnWritePass());
+ passManager.addPass(IREE::Stream::createElideAsyncCopiesPass());
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.h b/iree/compiler/Dialect/Stream/Transforms/Passes.h
index 7c411a6..4346b71 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.h
@@ -84,6 +84,8 @@
//===----------------------------------------------------------------------===//
std::unique_ptr<OperationPass<>> createEncodeTensorsPass();
+std::unique_ptr<OperationPass<>> createMaterializeCopyOnWritePass();
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createElideAsyncCopiesPass();
//===----------------------------------------------------------------------===//
// Diagnostics
diff --git a/iree/compiler/Dialect/Stream/Transforms/Passes.td b/iree/compiler/Dialect/Stream/Transforms/Passes.td
index 1e08f95..0273a39 100644
--- a/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -45,6 +45,22 @@
}];
}
+def MaterializeCopyOnWrite :
+ Pass<"iree-stream-materialize-copy-on-write", ""> {
+ let summary = "Materializes copy-on-write (🐄) behavior as explicit ops.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createMaterializeCopyOnWritePass()
+ }];
+}
+
+def ElideAsyncCopies :
+ Pass<"iree-stream-elide-async-copies", "mlir::ModuleOp"> {
+ let summary = "Elides copies when they are not performing meaningful work.";
+ let constructor = [{
+ mlir::iree_compiler::IREE::Stream::createElideAsyncCopiesPass()
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Diagnostics
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/BUILD b/iree/compiler/Dialect/Stream/Transforms/test/BUILD
index 2af6a6d..dbf3f58 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/Stream/Transforms/test/BUILD
@@ -18,7 +18,9 @@
srcs = enforce_glob(
[
"convert_to_stream.mlir",
+ "elide_async_copies.mlir",
"encode_tensors.mlir",
+ "materialize_copy_on_write.mlir",
"outline_constants.mlir",
],
include = ["*.mlir"],
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
index adcc6db..ba24858 100644
--- a/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -15,7 +15,9 @@
lit
SRCS
"convert_to_stream.mlir"
+ "elide_async_copies.mlir"
"encode_tensors.mlir"
+ "materialize_copy_on_write.mlir"
"outline_constants.mlir"
DATA
iree::tools::IreeFileCheck
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/elide_async_copies.mlir b/iree/compiler/Dialect/Stream/Transforms/test/elide_async_copies.mlir
new file mode 100644
index 0000000..200512b
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/test/elide_async_copies.mlir
@@ -0,0 +1,122 @@
+// RUN: iree-opt -split-input-file -iree-stream-elide-async-copies %s | IreeFileCheck %s
+
+// Tests that a normal clone-on-multiple-uses pattern has the last clone elided.
+// This is what the -iree-stream-materialize-copy-on-write pass generates and
+// expects us to clean up.
+
+// CHECK-LABEL: @multiUseTiedOperand
+func @multiUseTiedOperand(%size: index) -> (!stream.resource<*>, !stream.resource<*>) {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c123_i32 = arith.constant 123 : i32
+ %c456_i32 = arith.constant 456 : i32
+ %c789_i32 = arith.constant 789 : i32
+ // CHECK: %[[SPLAT:.+]] = stream.async.splat
+ %splat = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
+ // CHECK: %[[CLONE0:.+]] = stream.async.clone %[[SPLAT]]
+ %clone0 = stream.async.clone %splat : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
+ // CHECK: %[[FILL0:.+]] = stream.async.fill %c456_i32, %[[CLONE0]]
+ %fill0 = stream.async.fill %c456_i32, %clone0[%c0 to %c128 for %c128] : i32 -> %1 as !stream.resource<*>{%size}
+ // CHECK-NOT: stream.async.clone
+ %clone1 = stream.async.clone %splat : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
+ // CHECK: %[[FILL1:.+]] = stream.async.fill %c789_i32, %[[SPLAT]]
+ %fill1 = stream.async.fill %c789_i32, %clone1[%c128 to %c256 for %c128] : i32 -> %3 as !stream.resource<*>{%size}
+ // CHECK: return %[[FILL0]], %[[FILL1]]
+ return %fill0, %fill1 : !stream.resource<*>, !stream.resource<*>
+}
+
+// -----
+
+// Tests a copy of a by-value function argument gets elided.
+// Since the caller passes in the last live reference the callee is allowed to
+// mutate the memory in-place.
+
+// CHECK-LABEL: @argMoveCallee
+// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<*>
+func private @argMoveCallee(%arg: !stream.resource<*>, %size: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK-NOT: stream.async.clone
+ %clone = stream.async.clone %arg : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
+ // CHECK: %[[FILL:.+]] = stream.async.fill %c123_i32, %[[ARG0]]
+ %fill = stream.async.fill %c123_i32, %clone[%c0 to %c128 for %c128] : i32 -> %0 as !stream.resource<*>{%size}
+ // CHECK: return %[[FILL]]
+ return %fill : !stream.resource<*>
+}
+// CHECK: @argMoveCaller
+func @argMoveCaller(%size: index) -> !stream.resource<*> {
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: stream.async.splat
+ %splat = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
+ %result = call @argMoveCallee(%splat, %size) : (!stream.resource<*>, index) -> !stream.resource<*>
+ return %result : !stream.resource<*>
+}
+
+// -----
+
+// Tests a copy we cannot elide because the function argument is used after the
+// call and passed by const-reference.
+
+// CHECK-LABEL: @argCopyCallee
+func private @argCopyCallee(%arg: !stream.resource<*>, %size: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: stream.async.clone
+ %clone = stream.async.clone %arg : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
+ // CHECK: stream.async.fill
+ %fill = stream.async.fill %c123_i32, %clone[%c0 to %c128 for %c128] : i32 -> %0 as !stream.resource<*>{%size}
+ return %fill : !stream.resource<*>
+}
+// CHECK: @argCopyCaller
+func @argCopyCaller(%size: index) -> (!stream.resource<*>, !stream.resource<*>) {
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: stream.async.splat
+ %splat = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
+ %result = call @argCopyCallee(%splat, %size) : (!stream.resource<*>, index) -> !stream.resource<*>
+ return %splat, %result : !stream.resource<*>, !stream.resource<*>
+}
+
+// -----
+
+// Tests that block arguments that are chained as last-use will get their
+// clones elided while those that are used multiple times will not.
+// The first splat is analyzed to be threaded through as the last possible
+// use each time meaning that it can be mutated in place. The second splat
+// is conditionally chosen to be the initial splat or the new value and as such
+// needs to preserve the copy so the original splat is not mutated.
+
+// CHECK-LABEL: @blockArgMove
+// CHECK-SAME: (%[[COND:.+]]: i1
+func private @blockArgMove(%cond: i1, %size: index) -> (!stream.resource<*>, !stream.resource<*>) {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c123_i32 = arith.constant 123 : i32
+ %c456_i32 = arith.constant 456 : i32
+ // CHECK: %[[SPLAT0:.+]] = stream.async.splat %c123
+ %splat0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
+ // CHECK: %[[SPLAT1:.+]] = stream.async.splat %c456
+ %splat1 = stream.async.splat %c456_i32 : i32 -> !stream.resource<*>{%size}
+ // CHECK: br ^bb1(%[[SPLAT0]], %[[SPLAT1]]
+ br ^bb1(%splat0, %splat1 : !stream.resource<*>, !stream.resource<*>)
+// CHECK: ^bb1(%[[BB1_ARG0:.+]]: !stream.resource<*>, %[[BB1_ARG1:.+]]: !stream.resource<*>)
+^bb1(%bb1_0: !stream.resource<*>, %bb1_1: !stream.resource<*>):
+ // CHECK-NOT: stream.async.clone
+ %clone0 = stream.async.clone %bb1_0 : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
+ // CHECK: %[[FILL0:.+]] = stream.async.fill %c123_i32, %[[BB1_ARG0]]
+ %fill0 = stream.async.fill %c123_i32, %clone0[%c0 to %c128 for %c128] : i32 -> !stream.resource<*>{%size}
+ // CHECK: %[[CLONE1:.+]] = stream.async.clone %[[BB1_ARG1]]
+ %clone1 = stream.async.clone %bb1_1 : !stream.resource<*>{%size} -> !stream.resource<*>{%size}
+ // CHECK: %[[FILL1:.+]] = stream.async.fill %c456_i32, %[[CLONE1]]
+ %fill1 = stream.async.fill %c456_i32, %clone1[%c0 to %c128 for %c128] : i32 -> !stream.resource<*>{%size}
+ // CHECK: %[[SELECT:.+]] = select %[[COND]], %[[SPLAT1]], %[[FILL1]]
+ %bb1_1_new = select %cond, %splat1, %fill1 : !stream.resource<*>
+ // CHECK: cond_br %[[COND]], ^bb1(%[[FILL0]], %[[SELECT]]
+ // CHECK-SAME: ^bb2(%[[FILL0]], %[[SELECT]]
+ cond_br %cond, ^bb1(%fill0, %bb1_1_new : !stream.resource<*>, !stream.resource<*>),
+ ^bb2(%fill0, %bb1_1_new : !stream.resource<*>, !stream.resource<*>)
+^bb2(%bb2_0: !stream.resource<*>, %bb2_1: !stream.resource<*>):
+ return %bb2_0, %bb2_1 : !stream.resource<*>, !stream.resource<*>
+}
diff --git a/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir b/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir
new file mode 100644
index 0000000..db38078
--- /dev/null
+++ b/iree/compiler/Dialect/Stream/Transforms/test/materialize_copy_on_write.mlir
@@ -0,0 +1,96 @@
+// RUN: iree-opt -split-input-file -pass-pipeline='builtin.func(iree-stream-materialize-copy-on-write)' %s | IreeFileCheck %s
+
+// Tests that block arguments (including function arguments) are always cloned.
+// Until a whole-program analysis runs we don't know their semantics.
+
+// CHECK-LABEL: @blockArgsNeedCopies
+// CHECK-SAME: (%[[SRC:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index)
+func @blockArgsNeedCopies(%src: !stream.resource<*>, %size: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c123_i32 = arith.constant 123 : i32
+ // CHECK: %[[CLONE:.+]] = stream.async.clone %[[SRC]] : !stream.resource<*>{%[[SIZE]]} -> !stream.resource<*>{%[[SIZE]]}
+ // CHECK: %[[FILL:.+]] = stream.async.fill %c123_i32, %[[CLONE]]{{.+}} -> %[[CLONE]]
+ %0 = stream.async.fill %c123_i32, %src[%c0 to %c128 for %c128] : i32 -> %src as !stream.resource<*>{%size}
+ // CHECK: return %[[FILL]]
+ return %0 : !stream.resource<*>
+}
+
+// -----
+
+// Tests that copies are not inserted where they are trivially not needed.
+
+// CHECK-LABEL: @singleUseTiedOperand
+// CHECK-SAME: (%[[SIZE:.+]]: index)
+func @singleUseTiedOperand(%size: index) -> !stream.resource<*> {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c123_i32 = arith.constant 123 : i32
+ %c456_i32 = arith.constant 456 : i32
+ %c789_i32 = arith.constant 789 : i32
+ // CHECK: stream.async.splat
+ %0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
+ // CHECK-NOT: stream.async.clone
+ // CHECK: stream.async.fill
+ %1 = stream.async.fill %c456_i32, %0[%c0 to %c128 for %c128] : i32 -> %0 as !stream.resource<*>{%size}
+ // CHECK-NOT: stream.async.clone
+ // CHECK: stream.async.fill
+ %2 = stream.async.fill %c789_i32, %1[%c128 to %c256 for %c128] : i32 -> %0 as !stream.resource<*>{%size}
+ return %2 : !stream.resource<*>
+}
+
+// -----
+
+// Tests that copies are inserted when there are multiple uses of a mutated
+// value (in this case, the splat acting as an initializer). The additional
+// copy will be elided with the -iree-stream-elide-async-copies pass.
+
+// CHECK-LABEL: @multiUseTiedOperand
+// CHECK-SAME: (%[[SIZE:.+]]: index)
+func @multiUseTiedOperand(%size: index) -> (!stream.resource<*>, !stream.resource<*>) {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c256 = arith.constant 256 : index
+ %c123_i32 = arith.constant 123 : i32
+ %c456_i32 = arith.constant 456 : i32
+ %c789_i32 = arith.constant 789 : i32
+ // CHECK: %[[SPLAT:.+]] = stream.async.splat
+ %0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
+ // CHECK: %[[CLONE0:.+]] = stream.async.clone %[[SPLAT]]
+ // CHECK: %[[FILL0:.+]] = stream.async.fill %c456_i32, %[[CLONE0]]
+ %1 = stream.async.fill %c456_i32, %0[%c0 to %c128 for %c128] : i32 -> %0 as !stream.resource<*>{%size}
+ // CHECK: %[[CLONE1:.+]] = stream.async.clone %[[SPLAT]]
+ // CHECK: %[[FILL1:.+]] = stream.async.fill %c789_i32, %[[CLONE1]]
+ %2 = stream.async.fill %c789_i32, %0[%c128 to %c256 for %c128] : i32 -> %0 as !stream.resource<*>{%size}
+ return %1, %2 : !stream.resource<*>, !stream.resource<*>
+}
+
+// -----
+
+// Tests that block args (like function args) are copied until copy elision can
+// take care of them later.
+
+// CHECK-LABEL: @blockArgMove
+func @blockArgMove(%cond: i1, %size: index) -> (!stream.resource<*>, !stream.resource<*>) {
+ %c0 = arith.constant 0 : index
+ %c128 = arith.constant 128 : index
+ %c123_i32 = arith.constant 123 : i32
+ %c456_i32 = arith.constant 456 : i32
+ %splat0 = stream.async.splat %c123_i32 : i32 -> !stream.resource<*>{%size}
+ %splat1 = stream.async.splat %c456_i32 : i32 -> !stream.resource<*>{%size}
+ br ^bb1(%splat0, %splat1 : !stream.resource<*>, !stream.resource<*>)
+// CHECK: ^bb1(%[[ARG0:.+]]: !stream.resource<*>, %[[ARG1:.+]]: !stream.resource<*>)
+^bb1(%bb1_0: !stream.resource<*>, %bb1_1: !stream.resource<*>):
+ // CHECK: %[[CLONE0:.+]] = stream.async.clone %[[ARG0]]
+ // CHECK: stream.async.fill %c123_i32, %[[CLONE0]]
+ %fill0 = stream.async.fill %c123_i32, %bb1_0[%c0 to %c128 for %c128] : i32 -> !stream.resource<*>{%size}
+ // CHECK: %[[CLONE1:.+]] = stream.async.clone %[[ARG1]]
+ // CHECK: stream.async.fill %c456_i32, %[[CLONE1]]
+ %fill1 = stream.async.fill %c456_i32, %bb1_1[%c0 to %c128 for %c128] : i32 -> !stream.resource<*>{%size}
+ %bb1_1_new = select %cond, %splat1, %fill1 : !stream.resource<*>
+ cond_br %cond, ^bb1(%fill0, %bb1_1_new : !stream.resource<*>, !stream.resource<*>),
+ ^bb2(%fill0, %bb1_1_new : !stream.resource<*>, !stream.resource<*>)
+^bb2(%bb2_0: !stream.resource<*>, %bb2_1: !stream.resource<*>):
+ return %bb2_0, %bb2_1 : !stream.resource<*>, !stream.resource<*>
+}
diff --git a/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp b/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp
index bb6cf3f..d91f7f2 100644
--- a/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp
+++ b/iree/compiler/Dialect/Util/Analysis/DFX/Solver.cpp
@@ -15,6 +15,15 @@
namespace iree_compiler {
namespace DFX {
+Solver::~Solver() {
+ // Cleanup all elements; since we allocated them from the bump ptr allocator
+ // they won't have their destructors called otherwise. Some elements may have
+ // their own out-of-band allocations (like DenseMap) that would get leaked.
+ for (auto it : elementMap) {
+ it.second->~AbstractElement();
+ }
+}
+
LogicalResult Solver::run() {
LLVM_DEBUG(llvm::dbgs() << "[Solver] identified and initialized "
<< depGraph.syntheticRoot.deps.size()
diff --git a/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h b/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h
index a0ebe6e..b8b157f 100644
--- a/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h
+++ b/iree/compiler/Dialect/Util/Analysis/DFX/Solver.h
@@ -52,6 +52,7 @@
asmState(explorer.getAsmState()),
allocator(allocator),
depGraph(explorer.getAsmState()) {}
+ ~Solver();
// Initialized explorer for walking the IR.
Explorer &getExplorer() { return explorer; }
diff --git a/iree/compiler/Dialect/Util/Analysis/DFX/State.h b/iree/compiler/Dialect/Util/Analysis/DFX/State.h
index 9520aa6..d92bc60 100644
--- a/iree/compiler/Dialect/Util/Analysis/DFX/State.h
+++ b/iree/compiler/Dialect/Util/Analysis/DFX/State.h
@@ -429,7 +429,7 @@
}
// Maximum number of potential values to be tracked.
- static unsigned maxPotentialValues;
+ static constexpr unsigned maxPotentialValues = 32;
// Returns empty set as the best state of potential values.
static PotentialValuesState getBestState() {
diff --git a/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
index b0c64d0..5c40172 100644
--- a/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
+++ b/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
@@ -28,7 +28,8 @@
: rootOp(rootOp),
asmState(rootOp, OpPrintingFlags().elideLargeElementsAttrs()),
callGraph(rootOp),
- defaultAction(defaultAction) {}
+ defaultAction(defaultAction),
+ analysisManager(rootOp, /*passInstrumentor=*/nullptr) {}
Explorer::~Explorer() = default;
diff --git a/iree/compiler/Dialect/Util/Analysis/Explorer.h b/iree/compiler/Dialect/Util/Analysis/Explorer.h
index d8e094e..06c0858 100644
--- a/iree/compiler/Dialect/Util/Analysis/Explorer.h
+++ b/iree/compiler/Dialect/Util/Analysis/Explorer.h
@@ -15,6 +15,7 @@
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
@@ -119,6 +120,9 @@
// been specified.
void initialize();
+ // Returns a cached analysis manager for the root op.
+ AnalysisManager getAnalysisManager() { return analysisManager; }
+
// Cached information about a global variable.
struct GlobalInfo {
// Global variable definition.
@@ -288,6 +292,7 @@
DenseMap<OperationName, TraversalAction> opActions;
DenseMap<Operation *, GlobalInfo> globalInfos;
+ ModuleAnalysisManager analysisManager;
};
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Util/IR/UtilTypes.cpp b/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
index 20d9346..4f39cf1 100644
--- a/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
+++ b/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
@@ -207,7 +207,6 @@
bool detail::isOperandTied(Operation *op, unsigned operandIndex) {
auto tiedOp = dyn_cast<TiedOpInterface>(op);
if (!tiedOp) return false;
- SmallVector<Value> results;
auto tiedIndices = tiedOp.getTiedResultOperandIndices();
for (unsigned i = 0; i < tiedIndices.size(); ++i) {
if (tiedIndices[i] == operandIndex) {