Adding affinity analysis.
This performs whole-program analysis to enable querying the ideal
affinity for globals, execution ops, and resources. It can run at most
phases of compilation (including on linalg/flow IR), though it is
primarily used by the stream dialect passes such as conversion.
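
As a rough sketch of the intended usage from a pass (names such as
`moduleOp`, `someOp`, and `someValue` are placeholders; the API is the
one added in Stream/Analysis/Affinity.h below):

  // Run whole-program affinity analysis rooted at the module.
  AffinityAnalysis affinityAnalysis(moduleOp);
  if (failed(affinityAnalysis.run())) return signalPassFailure();
  // Query the ideal execution affinity of an op; a null attr means the
  // analysis failed or the result was ambiguous.
  if (auto affinityAttr = affinityAnalysis.lookupExecutionAffinity(someOp)) {
    // ... schedule |someOp| to execute with |affinityAttr| ...
  }
  // Query the ideal placement of a resource/tensor SSA value.
  if (auto affinityAttr = affinityAnalysis.lookupResourceAffinity(someValue)) {
    // ... allocate |someValue| with |affinityAttr| ...
  }
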
The `AnnotateAffinitiesPass` has been added to aid debugging and can be
enabled with the compiler `iree-stream-annotate-input-affinities` flag.
It has no impact on the generated program but can be useful when
affinity analysis fails during conversion.
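
For example, a hypothetical invocation (assuming the flag is exposed as
a regular `iree-compile` option with the usual `--` prefix; other flags
required for compilation are omitted here):

  iree-compile --iree-stream-annotate-input-affinities=true input.mlir
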
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
index 5a1227e..dcb0b0f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowInterfaces.td
@@ -9,8 +9,4 @@
include "iree/compiler/Dialect/Util/IR/UtilBase.td"
-//===----------------------------------------------------------------------===//
-// IREE::Flow::StreamableOpInterface
-//===----------------------------------------------------------------------===//
-
#endif // IREE_DIALECT_FLOW_INTERFACES
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
new file mode 100644
index 0000000..ac3c166
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.cpp
@@ -0,0 +1,1050 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
+
+#include <utility>
+
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+
+#define DEBUG_TYPE "iree-util-dfx"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+static const std::string getAffinitySetAsStr(
+ const DFX::PotentialValuesState<IREE::Stream::AffinityAttr> &state,
+ AsmState &asmState) {
+ std::string str;
+ llvm::raw_string_ostream sstream(str);
+ sstream << "pvs: ";
+ if (state.isValidState()) {
+ sstream << "[";
+ if (state.isUndefContained()) {
+ sstream << "undef, ";
+ }
+ llvm::interleaveComma(state.getAssumedSet(), sstream,
+ [&](IREE::Stream::AffinityAttr value) {
+ cast<Attribute>(value).print(sstream);
+ });
+ sstream << "]";
+ } else {
+ sstream << "(invalid)";
+ }
+ sstream.flush();
+ return str;
+}
+
+//===----------------------------------------------------------------------===//
+// Analysis elements
+//===----------------------------------------------------------------------===//
+
+class ValueProducerAffinityPVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::ValueElement> {
+public:
+ using BaseType =
+ DFX::StateWrapper<DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::ValueElement>;
+ using BaseType::BaseType;
+
+ static ValueProducerAffinityPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) ValueProducerAffinityPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override {
+ return "ValueProducerAffinityPVS";
+ }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override {
+ return getAffinitySetAsStr(getState(), asmState);
+ }
+
+private:
+ void initializeValue(Value value, DFX::Solver &solver) override;
+ ChangeStatus updateValue(Value value, DFX::Solver &solver) override;
+ void updateFromUse(Value value, OpOperand &operand, StateType &newState,
+ DFX::Solver &solver);
+
+ // Operations that the value is pinned to.
+ SetVector<Operation *> pinnedOps;
+};
+const char ValueProducerAffinityPVS::ID = 0;
+
+class GlobalAffinityPVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::TypedOperationElement<IREE::Util::GlobalOpInterface>> {
+public:
+ using BaseType = DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::TypedOperationElement<IREE::Util::GlobalOpInterface>>;
+ using BaseType::BaseType;
+
+ static GlobalAffinityPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) GlobalAffinityPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override { return "GlobalAffinityPVS"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override {
+ return getAffinitySetAsStr(getState(), asmState);
+ }
+
+private:
+ void initializeOperation(IREE::Util::GlobalOpInterface globalOp,
+ DFX::Solver &solver) override;
+ ChangeStatus updateOperation(IREE::Util::GlobalOpInterface globalOp,
+ DFX::Solver &solver) override;
+};
+const char GlobalAffinityPVS::ID = 0;
+
+class OpAffinityPVS : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::OperationElement> {
+public:
+ using BaseType =
+ DFX::StateWrapper<DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::OperationElement>;
+ using BaseType::BaseType;
+
+ static OpAffinityPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) OpAffinityPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override { return "OpAffinityPVS"; }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override {
+ return getAffinitySetAsStr(getState(), asmState);
+ }
+
+private:
+ void initializeOperation(Operation *op, DFX::Solver &solver) override;
+ ChangeStatus updateOperation(Operation *op, DFX::Solver &solver) override;
+};
+const char OpAffinityPVS::ID = 0;
+
+//===----------------------------------------------------------------------===//
+// ValueConsumerAffinityPVS
+//===----------------------------------------------------------------------===//
+
+class ValueConsumerAffinityPVS
+ : public DFX::StateWrapper<
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::ValueElement> {
+public:
+ using BaseType =
+ DFX::StateWrapper<DFX::PotentialValuesState<IREE::Stream::AffinityAttr>,
+ DFX::ValueElement>;
+ using BaseType::BaseType;
+
+ static ValueConsumerAffinityPVS &createForPosition(const Position &pos,
+ DFX::Solver &solver) {
+ return *(new (solver.getAllocator()) ValueConsumerAffinityPVS(pos));
+ }
+
+ // Identity definitions.
+ const std::string getName() const override {
+ return "ValueConsumerAffinityPVS";
+ }
+ const void *getID() const override { return &ID; }
+ static bool classof(const DFX::AbstractElement *element) {
+ return (element->getID() == &ID);
+ }
+ static const char ID;
+
+ const std::string getAsStr(AsmState &asmState) const override {
+ return getAffinitySetAsStr(getState(), asmState);
+ }
+
+private:
+ void initializeValue(Value value, DFX::Solver &solver) override;
+ ChangeStatus updateValue(Value value, DFX::Solver &solver) override;
+ TraversalResult updateFromUse(Value value, OpOperand &operand,
+ StateType &newState, DFX::Solver &solver);
+};
+const char ValueConsumerAffinityPVS::ID = 0;
+
+void ValueConsumerAffinityPVS::initializeValue(Value value,
+ DFX::Solver &solver) {}
+
+ChangeStatus ValueConsumerAffinityPVS::updateValue(Value value,
+ DFX::Solver &solver) {
+ StateType newState;
+ auto traversalResult = TraversalResult::COMPLETE;
+
+ // Walk into all consumers of the SSA value.
+ // Note that we may end up at multiple global stores of different globals
+ // by walking down through calls/branches/etc.
+ traversalResult |= solver.getExplorer().walkTransitiveUses(
+ value,
+ [&](OpOperand &operand) {
+ traversalResult |= updateFromUse(value, operand, newState, solver);
+ return WalkResult::advance();
+ },
+ (TraversalBehavior::DEFAULT | TraversalBehavior::DONT_WALK_TIED_VALUES));
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Incomplete traversal because of external call graph edges or pointers.
+ newState.unionAssumedWithUndef();
+ newState.indicatePessimisticFixpoint();
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+TraversalResult ValueConsumerAffinityPVS::updateFromUse(Value value,
+ OpOperand &operand,
+ StateType &newState,
+ DFX::Solver &solver) {
+ // If the value is consumed by an affinity-aware op then we can directly use
+ // the affinity specified on the op. A majority of the values we care about at
+ // the stream level are consumed by affinity-aware ops and earlier in the
+ // pipeline dialects may have transfer ops that define affinities we can
+ // anchor on.
+ if (auto affinityOp =
+ dyn_cast<IREE::Stream::AffinityOpInterface>(operand.getOwner())) {
+ auto opPVS = solver.getElementFor<OpAffinityPVS>(
+ *this, Position::forOperation(operand.getOwner()),
+ DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueConsumerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using consumer affinity from ";
+ operand.get().printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ opPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= opPVS;
+ }
+
+ // If the consumer op has the operand tied to one or more results then we walk
+ // through to track the transitive consumers. When this analysis runs we are
+ // usually still prior to baking out copy-on-write behavior so it's possible
+ // that the results of the tied operation end up in different places.
+ if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(operand.getOwner())) {
+ auto tiedResults = tiedOp.getOperandTiedResults(operand.getOperandNumber());
+ for (auto tiedResult : tiedResults) {
+ auto resultPVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(tiedResult), DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueConsumerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity referencing tied operand ";
+ operand.get().printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " result ";
+ tiedResult.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ resultPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= resultPVS;
+ }
+ }
+
+ // Handle consumers that are not affinity aware - this should cover any control
+ // flow ops so that we can track values that flow through the program.
+ return TypeSwitch<Operation *, TraversalResult>(operand.getOwner())
+ .Case([&](mlir::arith::SelectOp op) {
+ auto &resultPVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(op.getResult()),
+ DFX::Resolution::REQUIRED);
+ newState ^= resultPVS.getState();
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](mlir::BranchOpInterface op) {
+ return solver.getExplorer().walkOutgoingBranchOperandArguments(
+ op, operand.getOperandNumber(),
+ [&](Block *targetBlock, BlockArgument arg) {
+ auto &argUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(arg), DFX::Resolution::OPTIONAL);
+ newState ^= argUsage;
+ return WalkResult::advance();
+ });
+ })
+ .Case([&](mlir::scf::ForOp op) {
+ if (operand.getOperandNumber() >= op.getNumControlOperands()) {
+ int64_t blockIdx =
+ operand.getOperandNumber() - op.getNumControlOperands();
+ auto &beforeUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(op.getRegionIterArg(blockIdx)),
+ DFX::Resolution::REQUIRED);
+ newState ^= beforeUsage.getState();
+ }
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](mlir::scf::WhileOp op) {
+ auto &beforeUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ op.getBeforeBody()->getArgument(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= beforeUsage.getState();
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](mlir::scf::ConditionOp op) {
+ auto &parentUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ op->getParentOp()->getResult(operand.getOperandNumber() - 1)),
+ DFX::Resolution::REQUIRED);
+ newState ^= parentUsage.getState();
+ if (auto whileOp =
+ dyn_cast_or_null<mlir::scf::WhileOp>(op->getParentOp())) {
+ auto value = Position::forValue(
+ whileOp.getAfter().getArgument(operand.getOperandNumber() - 1));
+ auto &valueUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, value, DFX::Resolution::REQUIRED);
+ newState ^= valueUsage.getState();
+ }
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](mlir::scf::YieldOp op) {
+ if (isa<mlir::scf::IfOp>(op->getParentOp())) {
+ auto &operandUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(op->getOperand(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= operandUsage.getState();
+ auto &parentUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ op->getParentOp()->getResult(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= parentUsage.getState();
+ return TraversalResult::COMPLETE;
+ } else if (auto whileOp =
+ dyn_cast<mlir::scf::WhileOp>(op->getParentOp())) {
+ auto value = Position::forValue(
+ whileOp.getBefore().getArgument(operand.getOperandNumber()));
+ auto &valueUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, value, DFX::Resolution::REQUIRED);
+ newState ^= valueUsage.getState();
+ auto &parentUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ whileOp->getResult(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= parentUsage.getState();
+ return TraversalResult::COMPLETE;
+ } else if (auto forOp = dyn_cast<mlir::scf::ForOp>(op->getParentOp())) {
+ auto value = Position::forValue(
+ forOp.getRegionIterArg(operand.getOperandNumber()));
+ auto &valueUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, value, DFX::Resolution::REQUIRED);
+ newState ^= valueUsage.getState();
+ auto &parentUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(forOp->getResult(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= parentUsage.getState();
+ return TraversalResult::COMPLETE;
+ } else {
+ assert(false && "unhandled scf yield parent");
+ return TraversalResult::INCOMPLETE;
+ }
+ })
+ .Case([&](IREE::Util::ReturnOp op) {
+ return solver.getExplorer().walkIncomingCalls(
+ op->getParentOfType<mlir::CallableOpInterface>(),
+ [&](mlir::CallOpInterface callOp) {
+ auto &argUsage = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this,
+ Position::forValue(
+ callOp->getResult(operand.getOperandNumber())),
+ DFX::Resolution::OPTIONAL);
+ newState ^= argUsage;
+ return WalkResult::advance();
+ });
+ })
+ .Case([&](IREE::Util::OptimizationBarrierOp op) {
+ auto &resultPVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(op.getResult(operand.getOperandNumber())),
+ DFX::Resolution::REQUIRED);
+ newState ^= resultPVS.getState();
+ return TraversalResult::COMPLETE;
+ })
+ .Case([&](IREE::Util::GlobalStoreOpInterface op) {
+ auto *globalInfo =
+ solver.getExplorer().queryGlobalInfoFrom(op.getGlobalName(), op);
+ auto &globalPVS = solver.getElementFor<GlobalAffinityPVS>(
+ *this, Position::forOperation(globalInfo->op),
+ DFX::Resolution::REQUIRED);
+ newState ^= globalPVS.getState();
+ return TraversalResult::COMPLETE;
+ })
+ .Default([&](Operation *op) { return TraversalResult::COMPLETE; });
+}
+
+//===----------------------------------------------------------------------===//
+// ValueProducerAffinityPVS
+//===----------------------------------------------------------------------===//
+
+void ValueProducerAffinityPVS::initializeValue(Value value,
+ DFX::Solver &solver) {
+ solver.getExplorer().walkDefiningOps(value, [&](OpResult result) {
+ if (!isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ return WalkResult::skip();
+ }
+ if (auto affinityOp =
+ dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
+ result.getOwner())) {
+ if (affinityOp.pinsValueAffinity()) {
+ pinnedOps.insert(result.getOwner());
+ }
+ }
+ return WalkResult::advance();
+ });
+ solver.getExplorer().walkTransitiveUses(value, [&](OpOperand &operand) {
+ if (!isa<IREE::Stream::AffinityTypeInterface>(operand.get().getType())) {
+ return WalkResult::skip();
+ }
+ if (auto affinityOp =
+ dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
+ operand.getOwner())) {
+ if (affinityOp.pinsValueAffinity()) {
+ pinnedOps.insert(operand.getOwner());
+ }
+ }
+ return WalkResult::advance();
+ });
+}
+
+ChangeStatus ValueProducerAffinityPVS::updateValue(Value value,
+ DFX::Solver &solver) {
+ StateType newState;
+
+ // If there are any ops that produce the value and pin to a specific affinity
+ // then we take those directly and ignore all others.
+ if (!pinnedOps.empty()) {
+ for (auto pinnedOp : pinnedOps) {
+ auto &opPVS = solver.getElementFor<OpAffinityPVS>(
+ *this, Position::forOperation(pinnedOp), DFX::Resolution::REQUIRED);
+ newState ^= opPVS;
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+ }
+
+ // We special case some ops that act as barriers in the program. This prevents
+ // us from walking past boundaries that it is not profitable to cross; for
+ // example, globals are usually stored in independent contexts from where they
+ // are consumed.
+ if (auto barrierOp = dyn_cast_if_present<IREE::Util::OptimizationBarrierOp>(
+ value.getDefiningOp())) {
+ auto operand =
+ barrierOp.getOperand(cast<OpResult>(value).getResultNumber());
+ auto operandPVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(operand), DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using barrier op operand as ";
+ operandPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= operandPVS;
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+ } else if (auto loadOp =
+ dyn_cast_if_present<IREE::Util::GlobalLoadOpInterface>(
+ value.getDefiningOp())) {
+ auto *globalInfo = solver.getExplorer().queryGlobalInfoFrom(
+ loadOp.getGlobalName(), loadOp);
+ auto &globalPVS = solver.getElementFor<GlobalAffinityPVS>(
+ *this, Position::forOperation(globalInfo->op),
+ DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using global op affinity from "
+ << loadOp.getGlobalName() << " as ";
+ globalPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= globalPVS.getState();
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+ }
+
+ // Walk the program up into any possible producers of the value.
+ auto traversalResult = TraversalResult::COMPLETE;
+ traversalResult |= solver.getExplorer().walkDefiningOps(
+ value,
+ [&](OpResult result) {
+ if (isa<CallOpInterface>(result.getOwner())) {
+ return WalkResult::advance();
+ }
+
+ // If coming from an affinity-aware op that pins the value storage to a
+ // particular affinity that overrides all other logic.
+ if (auto affinityOp =
+ dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
+ result.getDefiningOp())) {
+ if (affinityOp.pinsValueAffinity()) {
+ auto &opPVS = solver.getElementFor<OpAffinityPVS>(
+ *this, Position::forOperation(affinityOp),
+ DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using pinned affinity from ";
+ result.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ opPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= opPVS;
+ newState.indicateOptimisticFixpoint();
+ return WalkResult::advance();
+ }
+ }
+
+ // If the result value is tied to an operand of the defining op then
+ // inherit the operand affinity.
+ if (auto tiedOp = dyn_cast_if_present<IREE::Util::TiedOpInterface>(
+ result.getDefiningOp())) {
+ auto operand = tiedOp.getTiedResultOperand(result);
+ if (operand) {
+ auto &valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(operand), DFX::Resolution::OPTIONAL);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity referencing tied operand ";
+ operand.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ valuePVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= valuePVS;
+ return WalkResult::advance();
+ }
+ }
+
+ // If the value is produced by the defining op then assume that the
+ // execution affinity dictates the result affinity.
+ if (auto affinityOp =
+ dyn_cast_if_present<IREE::Stream::AffinityOpInterface>(
+ result.getDefiningOp())) {
+ auto &opPVS = solver.getElementFor<OpAffinityPVS>(
+ *this, Position::forOperation(result.getOwner()),
+ DFX::Resolution::OPTIONAL);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " affinity using op affinity from result ";
+ result.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ opPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= opPVS;
+ return WalkResult::advance();
+ }
+
+ // Special handling for specific ops.
+ TypeSwitch<Operation *>(result.getOwner())
+ .Case<IREE::Util::GlobalLoadOpInterface>([&](auto loadOp) {
+ auto *globalInfo = solver.getExplorer().queryGlobalInfoFrom(
+ loadOp.getGlobalName(), loadOp);
+ auto &globalPVS = solver.getElementFor<GlobalAffinityPVS>(
+ *this, Position::forOperation(globalInfo->op),
+ DFX::Resolution::REQUIRED);
+ LLVM_DEBUG({
+ llvm::dbgs() << "[ValueProducerAffinityPVS] value ";
+ value.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs()
+ << " affinity using global op affinity from result ";
+ result.printAsOperand(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << " as ";
+ globalPVS.print(llvm::dbgs(), solver.getAsmState());
+ llvm::dbgs() << "\n";
+ });
+ newState ^= globalPVS.getState();
+ })
+ .Case<mlir::arith::SelectOp>([&](auto op) {
+ auto &truePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(op.getTrueValue()),
+ DFX::Resolution::REQUIRED);
+ newState ^= truePVS.getState();
+ auto &falsePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(op.getFalseValue()),
+ DFX::Resolution::REQUIRED);
+ newState ^= falsePVS.getState();
+ })
+ .Default([&](auto op) {
+ auto valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(result), DFX::Resolution::OPTIONAL);
+ newState ^= valuePVS;
+ });
+ return WalkResult::advance();
+ },
+ (TraversalBehavior::DEFAULT | TraversalBehavior::DONT_WALK_TIED_VALUES));
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Incomplete traversal because of external call graph edges or pointers.
+ newState.unionAssumedWithUndef();
+ newState.indicatePessimisticFixpoint();
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalAffinityPVS
+//===----------------------------------------------------------------------===//
+
+void GlobalAffinityPVS::initializeOperation(
+ IREE::Util::GlobalOpInterface globalOp, DFX::Solver &solver) {
+ // If an affinity is explicitly specified we take that over all analysis.
+ if (auto affinityAttr = IREE::Stream::AffinityAttr::lookup(globalOp)) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "[GlobalAffinityPVS] global @"
+ << globalOp.getGlobalName().getValue()
+ << " affinity explicitly specified as ";
+ affinityAttr.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ unionAssumed(affinityAttr);
+ indicateOptimisticFixpoint();
+ return;
+ }
+}
+
+ChangeStatus
+GlobalAffinityPVS::updateOperation(IREE::Util::GlobalOpInterface globalOp,
+ DFX::Solver &solver) {
+ StateType newState;
+ auto traversalResult = TraversalResult::COMPLETE;
+
+ const auto *globalInfo = solver.getExplorer().getGlobalInfo(globalOp);
+ if (globalInfo->isIndirect) {
+ traversalResult = TraversalResult::INCOMPLETE;
+ }
+
+ // Traverse all transitive uses of the global.
+ // We try to place globals where they are used as the common case is weights
+ // or parameters that are read more frequently than they are written.
+ // The reasoning is that if there are more writes than reads then unneeded
+ // work is being done, and otherwise there is always at least one read per
+ // write (and usually more reads than writes).
+ bool anyLoads = false;
+ for (auto loadOp : globalInfo->getLoads()) {
+ anyLoads = true;
+ auto &valuePVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(loadOp.getLoadedGlobalValue()),
+ DFX::Resolution::OPTIONAL);
+ if (valuePVS.isValidState()) {
+ newState ^= valuePVS;
+ }
+ }
+
+ // If there were no loads then take the affinity from stores.
+ // This is not common but can arise in tests or where the globals may be used
+ // to model side-effecting behavior.
+ if (!anyLoads) {
+ for (auto storeOp : globalInfo->getStores()) {
+ auto &valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(storeOp.getStoredGlobalValue()),
+ DFX::Resolution::OPTIONAL);
+ if (valuePVS.isValidState()) {
+ newState ^= valuePVS;
+ }
+ }
+ }
+
+ if (traversalResult == TraversalResult::INCOMPLETE) {
+ // Incomplete traversal because of external call graph edges or pointers.
+ newState.unionAssumedWithUndef();
+ newState.indicatePessimisticFixpoint();
+ }
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+//===----------------------------------------------------------------------===//
+// OpAffinityPVS
+//===----------------------------------------------------------------------===//
+
+void OpAffinityPVS::initializeOperation(Operation *op, DFX::Solver &solver) {
+ // If an affinity is explicitly specified we take that over all analysis.
+ if (auto affinityAttr = IREE::Stream::AffinityAttr::lookup(op)) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "[OpAffinityPVS] op ";
+ op->getName().print(llvm::dbgs());
+ llvm::dbgs() << " affinity explicitly specified as ";
+ affinityAttr.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+ unionAssumed(affinityAttr);
+ indicateOptimisticFixpoint();
+ return;
+ }
+}
+
+ChangeStatus OpAffinityPVS::updateOperation(Operation *op,
+ DFX::Solver &solver) {
+ StateType newState;
+
+ const bool consumesAny = llvm::any_of(
+ op->getOperandTypes(), +[](Type type) {
+ return isa<IREE::Stream::AffinityTypeInterface>(type);
+ });
+ if (consumesAny) {
+ for (auto operand : op->getOperands()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
+ auto valuePVS = solver.getElementFor<ValueProducerAffinityPVS>(
+ *this, Position::forValue(operand), DFX::Resolution::REQUIRED);
+ newState ^= valuePVS;
+ }
+ }
+ } else {
+ for (auto result : op->getResults()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ auto valuePVS = solver.getElementFor<ValueConsumerAffinityPVS>(
+ *this, Position::forValue(result), DFX::Resolution::REQUIRED);
+ newState ^= valuePVS;
+ }
+ }
+ }
+
+ return DFX::clampStateAndIndicateChange(getState(), newState);
+}
+
+//===----------------------------------------------------------------------===//
+// AffinityAnalysis
+//===----------------------------------------------------------------------===//
+
+// Tries to find a default affinity specified on an ancestor of |fromOp| and
+// adds it to |affinities|. Returns true if an affinity was found.
+static bool tryLookupDefaultAffinity(
+ Operation *fromOp,
+ SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ while (fromOp) {
+ auto affinityAttr = fromOp->getAttrOfType<IREE::Stream::AffinityAttr>(
+ "stream.affinity.default");
+ if (affinityAttr) {
+ affinities.push_back(affinityAttr);
+ return true;
+ }
+ fromOp = fromOp->getParentOp();
+ }
+ return false;
+}
+
+// Returns the first affinity if all affinities are compatible and otherwise
+// returns nullptr.
+static IREE::Stream::AffinityAttr
+trySelectLeadAffinity(ArrayRef<IREE::Stream::AffinityAttr> affinities) {
+ if (affinities.empty()) {
+ return {};
+ }
+ auto leadAffinityAttr = affinities.front();
+ for (size_t i = 1; i < affinities.size(); ++i) {
+ if (!IREE::Stream::AffinityAttr::areCompatible(affinities[i],
+ leadAffinityAttr)) {
+ return {};
+ }
+ }
+ return leadAffinityAttr;
+}
+
+// Sorts |affinities| in the natural affinity sort order.
+// We unfortunately have to do this as the PVS elements we source from are
+// unsorted.
+static void
+sortAffinities(SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ // HACK: this should probably do a type id ordering followed by a
+ // type-specific ordering (interface compare method?). We just need this to be
+ // stable as the affinities come from multiple DenseSets that have run-to-run
+ // ordering variance. This is very inefficient but is only used when there are
+ // multiple possible affinities and we try to avoid that anyway.
+ if (affinities.size() <= 1) {
+ return;
+ }
+ llvm::stable_sort(affinities, [](IREE::Stream::AffinityAttr lhs,
+ IREE::Stream::AffinityAttr rhs) {
+ std::string lhsStr;
+ llvm::raw_string_ostream lhsStream(lhsStr);
+ lhs.print(lhsStream);
+ std::string rhsStr;
+ llvm::raw_string_ostream rhsStream(rhsStr);
+ rhs.print(rhsStream);
+ return lhsStr < rhsStr;
+ });
+}
+
+AffinityAnalysis::AffinityAnalysis(Operation *rootOp)
+ : explorer(rootOp, TraversalAction::RECURSE), solver(explorer, allocator) {
+ explorer.setOpInterfaceAction<mlir::FunctionOpInterface>(
+ TraversalAction::RECURSE);
+
+ explorer.setDialectAction<mlir::scf::SCFDialect>(TraversalAction::RECURSE);
+
+ explorer.setDialectAction<IREE::Stream::StreamDialect>(
+ TraversalAction::RECURSE);
+ explorer.setOpAction<IREE::Stream::ExecutableOp>(TraversalAction::IGNORE);
+
+ explorer.initialize();
+}
+
+AffinityAnalysis::~AffinityAnalysis() = default;
+
+IREE::Stream::AffinityAttr
+AffinityAnalysis::lookupGlobalAffinity(Operation *op) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (!tryLookupGlobalAffinity(op, affinities) || affinities.empty()) {
+ return {};
+ }
+ if (affinities.size() == 1) {
+ return affinities.front();
+ }
+ return trySelectLeadAffinity(affinities);
+}
+
+bool AffinityAnalysis::tryLookupGlobalAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ auto globalPVS =
+ solver.lookupElementFor<GlobalAffinityPVS>(Position::forOperation(op));
+ if (!globalPVS || !globalPVS->isValidState() ||
+ globalPVS->isUndefContained()) {
+ // Analysis failed.
+ return false;
+ }
+ if (globalPVS->getAssumedSet().empty()) {
+ // Analysis completed but no affinity was specified; try to find a default.
+ return tryLookupDefaultAffinity(op, affinities);
+ }
+ for (auto affinityAttr : globalPVS->getAssumedSet()) {
+ affinities.push_back(affinityAttr);
+ }
+ sortAffinities(affinities);
+ return true;
+}
+
+IREE::Stream::AffinityAttr
+AffinityAnalysis::lookupExecutionAffinity(Operation *op) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (!tryLookupExecutionAffinity(op, affinities) || affinities.empty()) {
+ return {};
+ }
+ if (affinities.size() == 1) {
+ return affinities.front();
+ }
+ return trySelectLeadAffinity(affinities);
+}
+
+bool AffinityAnalysis::tryLookupExecutionAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ auto opPVS =
+ solver.lookupElementFor<OpAffinityPVS>(Position::forOperation(op));
+ if (!opPVS || !opPVS->isValidState() || opPVS->isUndefContained()) {
+ // Analysis failed.
+ return false;
+ }
+ if (opPVS->getAssumedSet().empty()) {
+ // Analysis completed but no affinity was specified; try to find a default.
+ return tryLookupDefaultAffinity(op, affinities);
+ }
+ for (auto affinityAttr : opPVS->getAssumedSet()) {
+ affinities.push_back(affinityAttr);
+ }
+ sortAffinities(affinities);
+ return true;
+}
+
+IREE::Stream::AffinityAttr
+AffinityAnalysis::inferExecutionAffinity(Operation *op) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (!tryInferExecutionAffinity(op, affinities) || affinities.empty()) {
+ return {};
+ }
+ if (affinities.size() == 1) {
+ return affinities.front();
+ }
+ return trySelectLeadAffinity(affinities);
+}
+
+bool AffinityAnalysis::tryInferExecutionAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
+ return tryLookupExecutionAffinity(op, affinities);
+ }
+ DFX::PotentialValuesState<IREE::Stream::AffinityAttr> opPVS;
+ const bool consumesAny = llvm::any_of(
+ op->getOperandTypes(), +[](Type type) {
+ return isa<IREE::Stream::AffinityTypeInterface>(type);
+ });
+ if (consumesAny) {
+ for (auto operand : op->getOperands()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
+ auto valuePVS = solver.lookupElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(operand), nullptr, DFX::Resolution::REQUIRED);
+ if (valuePVS && valuePVS->isValidState()) {
+ opPVS.unionAssumed(valuePVS->getState());
+ } else {
+ return false;
+ }
+ }
+ }
+ } else {
+ for (auto result : op->getResults()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ auto valuePVS = solver.lookupElementFor<ValueConsumerAffinityPVS>(
+ Position::forValue(result), nullptr, DFX::Resolution::REQUIRED);
+ if (valuePVS && valuePVS->isValidState()) {
+ opPVS.unionAssumed(valuePVS->getState());
+ } else {
+ return false;
+ }
+ }
+ }
+ }
+ if (!opPVS.isValidState() || opPVS.isUndefContained()) {
+ // Analysis failed.
+ return false;
+ }
+ if (opPVS.getAssumedSet().empty()) {
+ // Analysis completed but no affinity was specified; try to find a default.
+ return tryLookupDefaultAffinity(op, affinities);
+ }
+ for (auto affinityAttr : opPVS.getAssumedSet()) {
+ affinities.push_back(affinityAttr);
+ }
+ sortAffinities(affinities);
+ return true;
+}
+
+IREE::Stream::AffinityAttr
+AffinityAnalysis::lookupResourceAffinity(Value value) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (!tryLookupResourceAffinity(value, affinities) || affinities.empty()) {
+ return {};
+ }
+ if (affinities.size() == 1) {
+ return affinities.front();
+ }
+ return trySelectLeadAffinity(affinities);
+}
+
+bool AffinityAnalysis::tryLookupResourceAffinity(
+ Value value, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities) {
+ auto valuePVS = solver.lookupElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(value));
+ if (!valuePVS || !valuePVS->isValidState() || valuePVS->isUndefContained()) {
+ // Analysis failed.
+ return false;
+ }
+ if (valuePVS->getAssumedSet().empty()) {
+ // Analysis completed but no affinity was specified; try to find a default.
+ return tryLookupDefaultAffinity(value.getParentBlock()->getParentOp(),
+ affinities);
+ }
+ for (auto affinityAttr : valuePVS->getAssumedSet()) {
+ affinities.push_back(affinityAttr);
+ }
+ sortAffinities(affinities);
+ return true;
+}
+
+LogicalResult AffinityAnalysis::run() {
+ // Initialize globals so that we can assign them affinity.
+ explorer.forEachGlobal([&](const auto *globalInfo) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(
+ globalInfo->op.getGlobalType())) {
+ solver.getOrCreateElementFor<GlobalAffinityPVS>(
+ Position::forOperation(globalInfo->op));
+ }
+ });
+
+ // Initialize op execution affinities for any ops that use tracked types.
+ //
+ // TODO(benvanik): avoid doing this initialization for the entire module and
+ // instead rely on DFX to automatically populate the required abstract values.
+ // There's some missing logic in the element initialization, though, and by
+ // initializing all values we side-step that and work with test programs that
+ // may not have I/O edges that we could easily latch on to here.
+ explorer.forEachFunctionLikeOp([&](FunctionOpInterface funcOp) {
+ for (auto &block : funcOp.getBlocks()) {
+ for (auto arg : block.getArguments()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(arg.getType())) {
+ solver.getOrCreateElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(arg));
+ }
+ }
+ }
+ funcOp.walk([&](Operation *op) {
+ if (auto regionOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ for (auto &region : regionOp->getRegions()) {
+ for (auto arg : region.getArguments()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(arg.getType())) {
+ solver.getOrCreateElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(arg));
+ }
+ }
+ }
+ }
+ if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
+ solver.getOrCreateElementFor<OpAffinityPVS>(Position::forOperation(op));
+ }
+ for (auto result : op->getResults()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ solver.getOrCreateElementFor<ValueProducerAffinityPVS>(
+ Position::forValue(result));
+ }
+ }
+ });
+ });
+
+ if (failed(solver.run())) {
+ return failure(); // did not converge
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "\n\n[Analysis] affinity analysis results for the whole module:\n";
+ solver.print(llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ return success();
+}
+
+} // namespace mlir::iree_compiler::IREE::Stream
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h
new file mode 100644
index 0000000..3642a53
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Affinity.h
@@ -0,0 +1,102 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_
+#define IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_
+
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h"
+#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+//===----------------------------------------------------------------------===//
+// Affinity analysis
+//===----------------------------------------------------------------------===//
+
+// Performs whole-program analysis of resource and tensor value affinity.
+// All `!stream.resource` and `tensor` SSA values will be analyzed and their
+// affinities where used will be available for querying via the lookup
+// functions.
+class AffinityAnalysis {
+public:
+ explicit AffinityAnalysis(Operation *rootOp);
+ ~AffinityAnalysis();
+
+ // Runs analysis and populates the affinity results queried by the lookup functions.
+ // May fail if analysis cannot be completed due to unsupported or unknown IR.
+ LogicalResult run();
+
+ // Returns the affinity of the global |op| based on its loads.
+ // The global storage should be allocated with this affinity and available for
+ // fast access from any compatible affinity.
+ //
+ // If an explicit affinity is provided via a stream.affinity attribute then
+ // that will be used in place of analysis. If there is more than one consumer
+ // (such as multiple loads) with differing affinities, or analysis fails, then
+ // no affinity is returned. If all affinities are compatible, one will be
+ // chosen in an unspecified way.
+ IREE::Stream::AffinityAttr lookupGlobalAffinity(Operation *op);
+
+ // Populates all potential affinities of the global |op| in |affinities|.
+ // Returns false if analysis failed and the set of affinities is unknown.
+ bool tryLookupGlobalAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities);
+
+ // Returns the affinity of the executable |op| based on the op-specific rules
+ // as to whether its operands or results control placement. The operation
+ // should be scheduled to execute with this affinity and efficiently consume
+ // or produce resources that share a compatible affinity.
+ //
+ // If an explicit affinity is provided via stream.affinity attrs or the
+ // affinity op interface then that will be used in place of analysis. If there
+ // are multiple possible affinities, or analysis fails, no affinity is returned.
+ // If all affinities are compatible, one will be chosen in an unspecified way.
+ IREE::Stream::AffinityAttr lookupExecutionAffinity(Operation *op);
+
+ // Populates all potential execution affinities of |op| in |affinities|.
+ // Returns false if analysis failed and the set of affinities is unknown.
+ bool tryLookupExecutionAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities);
+
+ // Returns the affinity of |op| as if it were executable even if it is not.
+ // This relies on analysis of operands and results having resolved and
+ // otherwise returns nullptr indicating the op has no assumed affinity.
+ IREE::Stream::AffinityAttr inferExecutionAffinity(Operation *op);
+
+ // Populates all inferred potential execution affinities of |op| in
+ // |affinities|. This relies on analysis of operands and results having
+ // resolved and otherwise returns false indicating the op has no assumed
+ // affinity.
+ // Returns false if analysis failed and the set of affinities is unknown.
+ bool tryInferExecutionAffinity(
+ Operation *op, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities);
+
+ // Returns the affinity of |value| based on its producers.
+ // The resource should be allocated with this affinity and be usable by any
+ // compatible affinity.
+ //
+ // If there is more than one producer of the value (such as multiple callers)
+ // with differing affinities, or analysis fails, then no affinity is returned.
+ // If all affinities are compatible one will be chosen in an unspecified way.
+ IREE::Stream::AffinityAttr lookupResourceAffinity(Value value);
+
+ // Populates all potential affinities of |value| in |affinities|.
+ // Returns false if analysis failed and the set of affinities is unknown.
+ bool tryLookupResourceAffinity(
+ Value value, SmallVectorImpl<IREE::Stream::AffinityAttr> &affinities);
+
+private:
+ Explorer explorer;
+ llvm::BumpPtrAllocator allocator;
+ DFX::Solver solver;
+};
+
+} // namespace mlir::iree_compiler::IREE::Stream
+
+#endif // IREE_COMPILER_DIALECT_STREAM_ANALYSIS_AFFINITY_H_
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel
index 4e1421b..3cbb5b5 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/BUILD.bazel
@@ -15,12 +15,14 @@
iree_compiler_cc_library(
name = "Analysis",
srcs = [
+ "Affinity.cpp",
"Partitioning.cpp",
"Partitioning/ReferencePartitioning.cpp",
"ResourceHazards.cpp",
"ResourceUsage.cpp",
],
hdrs = [
+ "Affinity.h",
"Partitioning.h",
"ResourceHazards.h",
"ResourceUsage.h",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt
index f1b0fc8..c2dd74c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/CMakeLists.txt
@@ -14,10 +14,12 @@
NAME
Analysis
HDRS
+ "Affinity.h"
"Partitioning.h"
"ResourceHazards.h"
"ResourceUsage.h"
SRCS
+ "Affinity.cpp"
"Partitioning.cpp"
"Partitioning/ReferencePartitioning.cpp"
"ResourceHazards.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
index 1708782..4ff656c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp
@@ -17,7 +17,6 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -416,10 +415,14 @@
// TODO(benvanik): remove kFavorTransients.
bool isSourceExternal = !sourceUsage.isAssumed(NOT_EXTERNAL);
bool isTargetInternal = isAssumed(NOT_EXTERNAL);
- if (kFavorTransients && isSourceExternal && isTargetInternal) {
+ bool deviceChange =
+ op.getSourceAffinityAttr() != op.getResultAffinityAttr();
+ if ((kFavorTransients || deviceChange) && isSourceExternal &&
+ isTargetInternal) {
LLVM_DEBUG({
- llvm::dbgs() << "[ValueResourceUsage] skipping forward prop of "
- "external into internal: ";
+ llvm::dbgs()
+ << "[ValueResourceUsage] skipping forward prop of external "
+ "into internal due to kFavorTransients/device-change: ";
op.print(llvm::dbgs(), solver.getAsmState());
llvm::dbgs() << "\n";
});
@@ -529,7 +532,6 @@
*this,
Position::forValue(op.getBeforeBody()->getArgument(operandIdx)),
DFX::Resolution::REQUIRED);
-
getState() ^= beforeUsage.getState();
})
.Case([&](mlir::scf::ConditionOp op) {
@@ -562,29 +564,30 @@
Position::forValue(op->getParentOp()->getResult(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= parentUsage.getState();
- } else if (auto whileOp =
- dyn_cast_or_null<scf::WhileOp>(op->getParentOp())) {
+ } else if (auto whileOp = dyn_cast<scf::WhileOp>(op->getParentOp())) {
auto value =
Position::forValue(whileOp.getBefore().getArgument(operandIdx));
auto &valueUsage = solver.getElementFor<ValueResourceUsage>(
*this, value, DFX::Resolution::REQUIRED);
getState() ^= valueUsage.getState();
- } else if (auto forOp =
- dyn_cast_or_null<scf::ForOp>(op->getParentOp())) {
+ auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
+ *this, Position::forValue(whileOp->getResult(operandIdx)),
+ DFX::Resolution::REQUIRED);
+ getState() ^= parentUsage.getState();
+ } else if (auto forOp = dyn_cast<scf::ForOp>(op->getParentOp())) {
auto value = Position::forValue(forOp.getRegionIterArg(operandIdx));
auto &valueUsage = solver.getElementFor<ValueResourceUsage>(
*this, value, DFX::Resolution::REQUIRED);
getState() ^= valueUsage.getState();
-
auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(forOp->getResult(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= parentUsage.getState();
} else {
- assert(false && "Unsupported test case");
+ assert(false && "unhandled scf yield parent");
}
})
- .Case([&](mlir::func::ReturnOp op) {
+ .Case([&](IREE::Util::ReturnOp op) {
auto &operandUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
@@ -734,11 +737,14 @@
// TODO(benvanik): remove kFavorTransients.
bool isSourceInternal = isAssumed(NOT_EXTERNAL);
bool isTargetExternal = !resultUsage.isAssumed(NOT_EXTERNAL);
- if (kFavorTransients && isSourceInternal && isTargetExternal) {
+ bool deviceChange =
+ op.getSourceAffinityAttr() != op.getResultAffinityAttr();
+ if ((kFavorTransients || deviceChange) && isSourceInternal &&
+ isTargetExternal) {
LLVM_DEBUG({
llvm::dbgs()
<< "[ValueResourceUsage] skipping back prop of external into "
- "internal due to kFavorTransients: ";
+ "internal due to kFavorTransients/device-change: ";
op.print(llvm::dbgs(), solver.getAsmState());
llvm::dbgs() << "\n";
});
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
index bfcca44..4c8fb8d 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
@@ -504,6 +504,7 @@
}
def Stream_Resource : TypeDef<Stream_Dialect, "Resource", [
+ Stream_AffinityType,
Util_ReferenceType,
Util_SizeAwareType,
DeclareTypeInterfaceMethods<Util_GlobalType, [
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
index f17b753..2b686b7 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamInterfaces.td
@@ -97,7 +97,13 @@
// Returns an affinity active for the given operation.
// This will recursively walk parent operations until one with the
// `stream.affinity` attribute is found.
- static AffinityAttr lookup(Operation *op);
+ static AffinityAttr lookup(Operation *fromOp);
+
+ // Returns an affinity active for the given operation or the fallback
+ // default if none is specified.
+ // This will recursively walk parent operations until one with the
+ // `stream.affinity` attribute is found.
+ static AffinityAttr lookupOrDefault(Operation *fromOp);
// TODO(benvanik): replace with more fine-grained compatibility checks.
// "Compatible" can mean a lot of things: are they cache-coherent, are they
@@ -116,10 +122,24 @@
}
//===----------------------------------------------------------------------===//
+// IREE::Stream::AffinityTypeInterface
+//===----------------------------------------------------------------------===//
+
+def Stream_AffinityType : TypeInterface<"AffinityTypeInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
+ let description = [{
+ Indicates that a type represents a resource whose affinity is tracked.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// IREE::Stream::AffinityOpInterface
//===----------------------------------------------------------------------===//
def Stream_AffinityOp : OpInterface<"AffinityOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
TBD. Used to denote a stream affinity for ops and specify the kind of
environment the ops are expected run in.
@@ -142,6 +162,19 @@
>,
InterfaceMethod<
/*desc=*/[{
+ Returns true if the operands and results should be pinned to the
+ affinity of the op. This overrides all automatic placement logic.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"pinsValueAffinity",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
Returns the stream affinity for the op, indicating where it should run.
}],
/*retTy=*/"IREE::Stream::AffinityAttr",
@@ -149,7 +182,7 @@
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return dyn_cast_or_null<IREE::Stream::AffinityAttr>($_self->getAttr("affinity"));
+ return dyn_cast_or_null<IREE::Stream::AffinityAttr>($_op->getAttr("affinity"));
}]
>,
InterfaceMethod<
@@ -161,8 +194,8 @@
/*args=*/(ins "IREE::Stream::AffinityAttr":$value),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- if (value) $_self->setAttr("affinity", value);
- else $_self->removeAttr("affinity");
+ if (value) $_op->setAttr("affinity", value);
+ else $_op->removeAttr("affinity");
}]
>,
];
@@ -173,6 +206,8 @@
//===----------------------------------------------------------------------===//
def Stream_StreamableOp : OpInterface<"StreamableOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
Interface for ops that can be asynchronous executed in a streaming context.
}];
@@ -212,6 +247,8 @@
//===----------------------------------------------------------------------===//
def Stream_AsyncAccessOp : OpInterface<"AsyncAccessOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
Interface for stream.async.* ops that access subviews of resources.
This allows for some basic analysis and is only valid prior to allocation.
@@ -240,6 +277,8 @@
//===----------------------------------------------------------------------===//
def Stream_SubviewEffectOp : OpInterface<"SubviewEffectOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
Interface for ops that operate on subviews of resources used to query the
memory effects for subviews on operands.
@@ -258,6 +297,8 @@
//===----------------------------------------------------------------------===//
def Stream_TimelineOp : OpInterface<"TimelineOpInterface"> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::Stream";
+
let description = [{
Interface for ops that operate in an ordered sequence defined by timepoints.
}];
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
index 84cf1eb..9360eea 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
@@ -1231,7 +1231,7 @@
constantOp, constantOp.getResult().getType(), splatOp.getResult(),
resultSize, resultSize,
/*source_affinity=*/constantOp.getAffinityAttr(),
- /*result_affinity=*/nullptr);
+ /*result_affinity=*/constantOp.getAffinityAttr());
return success();
}
};
@@ -1452,9 +1452,9 @@
LogicalResult matchAndRewrite(AsyncConstantOp constantOp,
PatternRewriter &rewriter) const override {
auto value = dyn_cast<ElementsAttr>(constantOp.getValue());
- if (!value || !value.isSplat())
+ if (!value || !value.isSplat()) {
return failure();
-
+ }
auto splatElementAttr =
llvm::dyn_cast<SplatElementsAttr>(value).getSplatValue<TypedAttr>();
auto splatValue = rewriter.create<arith::ConstantOp>(
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
index 698c7b9..358d1fd 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
@@ -2020,14 +2020,57 @@
}
IREE::Stream::AffinityAttr AsyncTransferOp::getAffinityAttr() {
- return getResultAffinityAttr();
+ auto sourceType = cast<IREE::Stream::ResourceType>(getSource().getType());
+ auto resultType = cast<IREE::Stream::ResourceType>(getResult().getType());
+ if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging &&
+ resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // TODO(multi-device): figure out how to model staging->staging transfers.
+ return getSourceAffinityAttr();
+ } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // If source is staging then the op should execute on the consumer.
+ return getResultAffinityAttr();
+ } else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // If result is staging then the op should execute on the producer.
+ return getSourceAffinityAttr();
+ } else {
+ // Default to result affinity.
+ return getResultAffinityAttr();
+ }
}
void AsyncTransferOp::setAffinityAttr(IREE::Stream::AffinityAttr value) {
- if (value)
- setResultAffinityAttr(value);
- else
- removeResultAffinityAttr();
+ auto sourceType = cast<IREE::Stream::ResourceType>(getSource().getType());
+ auto resultType = cast<IREE::Stream::ResourceType>(getResult().getType());
+ if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging &&
+ resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // TODO(multi-device): figure out how to model staging->staging transfers.
+ if (value) {
+ setSourceAffinityAttr(value);
+ } else {
+ removeSourceAffinityAttr();
+ }
+ } else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // If source is staging then the op should execute on the consumer.
+ if (value) {
+ setResultAffinityAttr(value);
+ } else {
+ removeResultAffinityAttr();
+ }
+ } else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
+ // If result is staging then the op should execute on the producer.
+ if (value) {
+ setSourceAffinityAttr(value);
+ } else {
+ removeSourceAffinityAttr();
+ }
+ } else {
+ // Default to result affinity.
+ if (value) {
+ setResultAffinityAttr(value);
+ } else {
+ removeResultAffinityAttr();
+ }
+ }
}
void AsyncTransferOp::getAsyncAccessRanges(
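For reference, a minimal helper sketch (hypothetical, not part of this change) that captures the rule implemented above: transfers whose result is a staging resource (downloads, plus the staging->staging placeholder case) execute on the producer, while everything else executes on the consumer.

static bool transferExecutesOnConsumer(IREE::Stream::AsyncTransferOp op) {
  auto resultLifetime =
      cast<IREE::Stream::ResourceType>(op.getResult().getType()).getLifetime();
  // Downloads (and staging->staging) use the source affinity; uploads and
  // device->device transfers use the result affinity.
  return resultLifetime != IREE::Stream::Lifetime::Staging;
}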
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
index dbe5207..871e3bb 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -1520,7 +1520,7 @@
let assemblyFormat = [{
(`on` `(` $affinity^ `)`)?
- $value `,` $target `[` $start_indices `for` $lengths `]` `:`
+ $value `,` $target (`[` $start_indices `for` $lengths^ `]`)? `:`
type($value)
`->`
$target_encoding (`` `{` $target_encoding_dims^ `}`)?
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
index 19c2410..82b8609 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.cpp
@@ -335,20 +335,40 @@
// #stream.affinity
//===----------------------------------------------------------------------===//
-AffinityAttr AffinityAttr::lookup(Operation *op) {
- auto attrId = StringAttr::get(op->getContext(), "stream.affinity");
- while (op) {
- if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(op)) {
- auto affinity = affinityOp.getAffinityAttr();
- if (affinity)
+// static
+AffinityAttr AffinityAttr::lookup(Operation *fromOp) {
+ auto attrId = StringAttr::get(fromOp->getContext(), "stream.affinity");
+ while (fromOp) {
+ if (auto affinityOp = llvm::dyn_cast<AffinityOpInterface>(fromOp)) {
+ if (auto affinity = affinityOp.getAffinityAttr()) {
return affinity;
+ }
}
- auto attr = op->getAttrOfType<AffinityAttr>(attrId);
- if (attr)
+ if (auto attr = fromOp->getAttrOfType<AffinityAttr>(attrId)) {
return attr;
- op = op->getParentOp();
+ }
+ fromOp = fromOp->getParentOp();
}
- return {}; // No affinity found; let caller decide what to do.
+ // No affinity found; let caller decide what to do.
+ return {};
+}
+
+// static
+AffinityAttr AffinityAttr::lookupOrDefault(Operation *fromOp) {
+ if (auto affinityAttr = AffinityAttr::lookup(fromOp)) {
+ return affinityAttr; // found a specified affinity
+ }
+ auto attrId =
+ StringAttr::get(fromOp->getContext(), "stream.affinity.default");
+ while (fromOp) {
+ if (auto affinityAttr =
+ fromOp->getAttrOfType<IREE::Stream::AffinityAttr>(attrId)) {
+ return affinityAttr;
+ }
+ fromOp = fromOp->getParentOp();
+ }
+ // No affinity or default found; let caller decide what to do.
+ return {};
}
// static
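A minimal usage sketch for the new lookupOrDefault helper (the resolveAffinity wrapper below is hypothetical caller code, not part of this change), assuming the StreamTypes.h declarations are in scope:

static IREE::Stream::AffinityAttr resolveAffinity(Operation *op) {
  // lookupOrDefault first behaves like lookup (checking AffinityOpInterface
  // ops and `stream.affinity` attrs up the parent chain) and then retries the
  // walk looking for a `stream.affinity.default` attr.
  if (auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(op)) {
    return affinityAttr;
  }
  // Neither an explicit affinity nor a default was found; the caller decides
  // what to do (emit a diagnostic, fall back to a host-local placement, etc).
  return {};
}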
diff --git a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
index 42b8424..d69e226 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
+++ b/compiler/src/iree/compiler/Dialect/Stream/IR/StreamTypes.h
@@ -69,9 +69,7 @@
#include "iree/compiler/Dialect/Stream/IR/StreamAttrInterfaces.h.inc" // IWYU pragma: export
-namespace mlir::iree_compiler::IREE::Stream {
#include "iree/compiler/Dialect/Stream/IR/StreamTypeInterfaces.h.inc" // IWYU pragma: export
-} // namespace mlir::iree_compiler::IREE::Stream
// clang-format off: must be included after all LLVM/MLIR headers.
#define GET_TYPEDEF_CLASSES
@@ -99,8 +97,12 @@
const AsyncAccessRange &rhs);
};
+} // namespace mlir::iree_compiler::IREE::Stream
+
#include "iree/compiler/Dialect/Stream/IR/StreamOpInterfaces.h.inc" // IWYU pragma: export
+namespace mlir::iree_compiler::IREE::Stream {
+
//===----------------------------------------------------------------------===//
// custom<ParameterReference>($scope, $key)
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp
new file mode 100644
index 0000000..62b9db2
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AnnotateAffinities.cpp
@@ -0,0 +1,127 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
+#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::Stream {
+
+#define GEN_PASS_DEF_ANNOTATEAFFINITIESPASS
+#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-stream-annotate-affinities
+//===----------------------------------------------------------------------===//
+
+static void annotateOp(Operation *op,
+ ArrayRef<IREE::Stream::AffinityAttr> affinities) {
+ auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op);
+ if (!affinityOp || !affinityOp.requiresAffinity()) {
+ return;
+ }
+ if (!affinities.empty()) {
+ op->setAttr("stream.affinities",
+ ArrayAttr::get(op->getContext(),
+ llvm::to_vector_of<Attribute>(affinities)));
+ }
+}
+
+static void annotateGlobalOp(IREE::Util::GlobalOpInterface globalOp,
+ AffinityAnalysis &affinityAnalysis) {
+ if (!isa<IREE::Stream::AffinityTypeInterface>(globalOp.getGlobalType())) {
+ return;
+ }
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (affinityAnalysis.tryLookupGlobalAffinity(globalOp, affinities)) {
+ annotateOp(globalOp, affinities);
+ }
+}
+
+static void annotateOperandsAndResults(Operation *op,
+ AffinityAnalysis &affinityAnalysis) {
+ auto emptyArray = ArrayAttr::get(op->getContext(), {});
+ SmallVector<Attribute> operandAttrs;
+ for (auto operand : op->getOperands()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(operand.getType())) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (affinityAnalysis.tryLookupResourceAffinity(operand, affinities)) {
+ operandAttrs.push_back(ArrayAttr::get(
+ op->getContext(), llvm::to_vector_of<Attribute>(affinities)));
+ } else {
+ operandAttrs.push_back(emptyArray);
+ }
+ }
+ }
+ SmallVector<Attribute> resultAttrs;
+ for (auto result : op->getResults()) {
+ if (isa<IREE::Stream::AffinityTypeInterface>(result.getType())) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (affinityAnalysis.tryLookupResourceAffinity(result, affinities)) {
+ resultAttrs.push_back(ArrayAttr::get(
+ op->getContext(), llvm::to_vector_of<Attribute>(affinities)));
+ } else {
+ resultAttrs.push_back(emptyArray);
+ }
+ }
+ }
+ if (!operandAttrs.empty()) {
+ op->setAttr("stream.affinities.operands",
+ ArrayAttr::get(op->getContext(), operandAttrs));
+ }
+ if (!resultAttrs.empty()) {
+ op->setAttr("stream.affinities.results",
+ ArrayAttr::get(op->getContext(), resultAttrs));
+ }
+}
+
+static void annotateFuncOp(FunctionOpInterface funcOp,
+ AffinityAnalysis &affinityAnalysis) {
+ funcOp.walk([&](Operation *op) {
+ SmallVector<IREE::Stream::AffinityAttr> affinities;
+ if (affinityAnalysis.tryLookupExecutionAffinity(op, affinities)) {
+ annotateOp(op, affinities);
+ }
+ annotateOperandsAndResults(op, affinityAnalysis);
+ });
+}
+
+struct AnnotateAffinitiesPass
+ : public IREE::Stream::impl::AnnotateAffinitiesPassBase<
+ AnnotateAffinitiesPass> {
+ void runOnOperation() override {
+ // Run affinity analysis on the whole module.
+ AffinityAnalysis affinityAnalysis(getOperation());
+ if (failed(affinityAnalysis.run())) {
+ return signalPassFailure();
+ }
+
+ // Annotate all ops with derived affinities.
+ for (auto &op : getOperation().getOps()) {
+ if (op.hasTrait<OpTrait::IREE::Util::ObjectLike>())
+ continue;
+ if (auto globalOp = dyn_cast<IREE::Util::GlobalOpInterface>(op)) {
+ annotateGlobalOp(globalOp, affinityAnalysis);
+ } else if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ annotateFuncOp(funcOp, affinityAnalysis);
+ }
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::Stream
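As a rough illustration of how other passes could consume the analysis (the pinOpIfUnambiguous helper below is hypothetical and simply mirrors the queries used by the annotation pass above):

static void pinOpIfUnambiguous(Operation *op,
                               AffinityAnalysis &affinityAnalysis) {
  SmallVector<IREE::Stream::AffinityAttr> affinities;
  if (!affinityAnalysis.tryLookupExecutionAffinity(op, affinities) ||
      affinities.size() != 1) {
    return; // unknown or ambiguous; leave for a placement policy to decide
  }
  if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
    affinityOp.setAffinityAttr(affinities.front());
  }
}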
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
index 1a1d2e2..d2f326c 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel
@@ -15,6 +15,7 @@
iree_compiler_cc_library(
name = "Transforms",
srcs = [
+ "AnnotateAffinities.cpp",
"AnnotateDispatchArguments.cpp",
"ConvertToStream.cpp",
"DumpStatistics.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
index 9d78c8e..5eb3d27 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@
HDRS
"Passes.h"
SRCS
+ "AnnotateAffinities.cpp"
"AnnotateDispatchArguments.cpp"
"ConvertToStream.cpp"
"DumpStatistics.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index b99b792..31a5bb6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -16,6 +16,12 @@
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
+static llvm::cl::opt<bool> clAnnotateInputAffinities(
+ "iree-stream-annotate-input-affinities",
+ llvm::cl::desc("Annotates all tensor/resource affinities on the input to "
+ "the pipeline for debugging."),
+ llvm::cl::init(false));
+
namespace mlir::iree_compiler::IREE::Stream {
using FunctionLikeNest =
@@ -68,6 +74,13 @@
// Conversion
//----------------------------------------------------------------------------
+ // Annotate all ops/resources with the analyzed affinities.
+ // This should not change the behavior of conversion but allows for
+ // debugging of analysis errors in end-user tooling.
+ if (clAnnotateInputAffinities) {
+ passManager.addPass(IREE::Stream::createAnnotateAffinitiesPass());
+ }
+
// Converts from all input dialects into various levels of the stream dialect.
// Tensor-like things go to stream.tensor.* ops while lower level buffer-like
// things will go to stream.async.* ops.
@@ -81,6 +94,9 @@
// Constant/variable optimization
//----------------------------------------------------------------------------
+ // Run inlining after having baked out affinities.
+ passManager.addPass(mlir::createInlinerPass());
+
// Cleanup globals that were created during conversion.
addCleanupPatterns(passManager);
@@ -96,10 +112,15 @@
// TODO(benvanik): annotate all dispatches with preferred executable affinity.
// TODO(benvanik): DFA to specify all value affinities and pin dispatches.
+ // TODO(multi-device): it'd be really nice to be able to verify here, but
+ // doing so prevents compiling to stream without devices specified or
+ // continuing compilation at various phases. It'd be nice to find a way to
+ // enable this when the user expects it to work and otherwise not.
+ //
// Verify that all ops that may require affinities have them assigned or
// available (on a parent scope, etc). This allows subsequent passes to trust
// that an affinity lookup will always return a valid affinity.
- passManager.addPass(IREE::Stream::createVerifyAffinitiesPass());
+ // passManager.addPass(IREE::Stream::createVerifyAffinitiesPass());
}
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
index ca2ec3a..f5ee39f 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
@@ -457,6 +457,11 @@
// Diagnostics
//===----------------------------------------------------------------------===//
+def AnnotateAffinitiesPass :
+ Pass<"iree-stream-annotate-affinities", "mlir::ModuleOp"> {
+ let summary = "Annotates affinities on all ops for debugging.";
+}
+
def DumpStatisticsPass :
Pass<"iree-stream-dump-statistics", "mlir::ModuleOp"> {
let summary = "Dumps stream dialect usage information to a file.";
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
index 1bec564..c8510a6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
@@ -668,7 +668,8 @@
return llvm::cast<IREE::Stream::ResourceType>(value.getType())
.getLifetime() == IREE::Stream::Lifetime::Staging;
};
- auto currentAffinityAttr = IREE::Stream::AffinityAttr::lookup(asyncOp);
+ auto currentAffinityAttr =
+ IREE::Stream::AffinityAttr::lookupOrDefault(asyncOp);
bool transferIn = asyncOp.getSourceAffinityAttr() != currentAffinityAttr ||
isStaging(asyncOp.getSource());
bool transferOut = asyncOp.getResultAffinityAttr() != currentAffinityAttr ||
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp
index 042bbb8..7579244 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/VerifyAffinities.cpp
@@ -27,7 +27,7 @@
verifyAffinityAssigned(IREE::Stream::AffinityOpInterface op) {
if (!op.requiresAffinity()) {
return success(); // does not require an affinity
- } else if (IREE::Stream::AffinityAttr::lookup(op)) {
+ } else if (IREE::Stream::AffinityAttr::lookupOrDefault(op)) {
return success(); // has an affinity
}
return op->emitOpError()
@@ -55,7 +55,10 @@
return WalkResult::interrupt();
}
}
- return WalkResult::advance();
+ return (op->hasTrait<OpTrait::IREE::Util::ObjectLike>() ||
+ op->hasTrait<OpTrait::SymbolTable>())
+ ? WalkResult::skip()
+ : WalkResult::advance();
})
.wasInterrupted())
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
index 524a1ce..362d672 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel
@@ -16,6 +16,7 @@
name = "lit",
srcs = enforce_glob(
[
+ "annotate_affinities.mlir",
"annotate_dispatch_arguments.mlir",
"convert_to_stream.mlir",
"dump_statistics.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
index 5ea9811..fe83ee6 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@
NAME
lit
SRCS
+ "annotate_affinities.mlir"
"annotate_dispatch_arguments.mlir"
"convert_to_stream.mlir"
"dump_statistics.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir
new file mode 100644
index 0000000..c3e1f1e
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/annotate_affinities.mlir
@@ -0,0 +1,1549 @@
+// RUN: iree-opt --split-input-file --iree-stream-annotate-affinities %s | FileCheck %s
+
+// Tests that we can track affinity through optimization barriers. They're meant
+// to block optimization, but we really can't do much if we don't track affinity.
+// We could change this in the future, but tests would be harder to write and
+// there's not a lot that can be done with an unassigned resource.
+
+// CHECK-LABEL: @optimization_barrier_consumer
+util.func private @optimization_barrier_consumer() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: util.optimization_barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_dno = util.optimization_barrier %cst : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst_dno : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @optimization_barrier_producer
+util.func private @optimization_barrier_producer() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: util.optimization_barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a_dno = util.optimization_barrier %cst_a : tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a_dno : tensor<1xi32>
+}
+
+// -----
+
+// Tests that constant-like ops get placed with their consumer(s).
+// We want to replicate constants where they are consumed instead of performing
+// transfers at runtime to move them around, and by placing them with their
+// consumers we know early on when replication is needed.
+
+// CHECK-LABEL: @constant_op
+util.func private @constant_op() -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ util.return %cst_a, %cst_b : tensor<1xi32>, tensor<1xi32>
+}
+
+// -----
+
+// Tests that splats (not constant-like, but with no consumed values) are placed
+// with their consumer(s). These are always best to rematerialize where they are
+// consumed to avoid allocating/transferring a bunch of repeated values.
+
+// CHECK-LABEL: @splat_op
+util.func private @splat_op() -> tensor<1xi32> {
+ %splat_value = arith.constant 123 : i32
+ // CHECK: flow.tensor.splat
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %splat = flow.tensor.splat %splat_value : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %splat_a = flow.tensor.transfer %splat : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %splat_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that imported tensor placement is inherited.
+// Frontends can use this to declare where they expect their arguments to
+// be living at the time the functions are invoked. Imports do not perform
+// transfers so we must use whatever is declared.
+
+// CHECK-LABEL: @imported_tensor
+util.func public @imported_tensor(%buffer_view: !hal.buffer_view, %fence: !hal.fence) -> tensor<1xi32> {
+ // CHECK: hal.tensor.import
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tensor = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%fence) => %buffer_view "input" : !hal.buffer_view -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %tensor : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops exported to buffers are properly placed.
+// Frontends can use this to explicitly define where exported tensors must live.
+// With consumer-placed ops like constants or splats we place them directly on
+// the export target.
+
+// CHECK-LABEL: @exported_constant
+util.func public @exported_constant(%fence: !hal.fence) -> !hal.buffer_view {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_ready = hal.tensor.barrier join(%cst : tensor<1xi32>) => %fence : !hal.fence
+ // CHECK: hal.tensor.export
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %buffer_view = hal.tensor.export on(#hal.device.promise<@dev_a>) %cst_ready "output" : tensor<1xi32> -> !hal.buffer_view
+ util.return %buffer_view : !hal.buffer_view
+}
+
+// -----
+
+// Tests that producer-placed ops exported to buffers get the appropriate
+// affinity on both devices. Frontends can use this to explicitly define where
+// exported tensors must live. Transfers may need to be inserted in order to
+// respect the required affinities. Note here that the operand to the export
+// is on @dev_a instead of the requested @dev_b.
+
+// CHECK-LABEL: @exported_producer
+util.func public @exported_producer(%fence: !hal.fence) -> !hal.buffer_view {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.clone
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %clone_a = flow.tensor.clone %cst_a : tensor<1xi32>
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %clone_ready_a = hal.tensor.barrier join(%clone_a : tensor<1xi32>) => %fence : !hal.fence
+ // CHECK: hal.tensor.export
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ %buffer_view = hal.tensor.export on(#hal.device.promise<@dev_b>) %clone_ready_a "output" : tensor<1xi32> -> !hal.buffer_view
+ // CHECK: util.return
+ util.return %buffer_view : !hal.buffer_view
+}
+
+// -----
+
+// Tests in-place aliased storage for results.
+// Frontends require that the storage be placed as indicated even if that means
+// introducing transfers such that the operation is not in-place.
+
+// CHECK-LABEL: @aliased_storage
+util.func public @aliased_storage(%view: !hal.buffer_view, %storage: !hal.buffer, %fence: !hal.fence) {
+ // CHECK: hal.tensor.import
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg_a = hal.tensor.import on(#hal.device.promise<@dev_a>) %view : !hal.buffer_view -> tensor<4xi32>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %ret_b = flow.dispatch @dispatch(%arg_a) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: hal.tensor.alias
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %alias_b = hal.tensor.alias on(#hal.device.promise<@dev_b>) %ret_b : tensor<4xi32> to %storage : !hal.buffer
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ hal.tensor.barrier join(%alias_b : tensor<4xi32>) => %fence : !hal.fence
+ util.return
+}
+
+// -----
+
+// Tests aliased storage through tied dispatches.
+
+// CHECK-LABEL: @tied_aliased_storage
+util.func public @tied_aliased_storage(%view: !hal.buffer_view, %storage: !hal.buffer, %fence: !hal.fence) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch0
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %t0 = flow.dispatch @dispatch0(%cst) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: flow.dispatch @dispatch1
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %t1 = flow.dispatch @dispatch1(%t0) : (tensor<4xi32>) -> %t0
+ // CHECK: hal.tensor.alias
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %alias = hal.tensor.alias on(#hal.device.promise<@dev_b>) %t1 : tensor<4xi32> to %storage : !hal.buffer
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ hal.tensor.barrier join(%alias : tensor<4xi32>) => %fence : !hal.fence
+ util.return
+}
+
+// -----
+
+// Tests that consumer-placed ops that pass through tied ops get attributed to
+// a single consumer.
+
+// CHECK-LABEL: @tied_constant
+util.func private @tied_constant() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.dispatch @a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied = flow.dispatch @a(%cst) : (tensor<1xi32>) -> %cst
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied_a = flow.tensor.transfer %tied : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %tied_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops that pass through tied ops get attributed to
+// transitive consumers. This is not ideal but allows the application of
+// replication policies.
+
+// CHECK-LABEL: @tied_constant_multi_consumer
+util.func private @tied_constant_multi_consumer() -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.dispatch @a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %tied_0 = flow.dispatch @a(%cst) : (tensor<1xi32>) -> %cst
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied_0_a = flow.tensor.transfer %tied_0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %tied_1 = flow.dispatch @b(%cst) : (tensor<1xi32>) -> %cst
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %tied_1_b = flow.tensor.transfer %tied_1 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ util.return %tied_0_a, %tied_1_b : tensor<1xi32>, tensor<1xi32>
+}
+
+// -----
+
+// Tests that transferring consumer-placed values prior to multiple tied uses
+// doesn't pollute the execution affinity of ops after the transfers. Note that
+// the constant will still have multiple affinities to allow for policies that
+// replicate the constant.
+
+// CHECK-LABEL: @tied_transfer_constant_multi_consumer
+util.func private @tied_transfer_constant_multi_consumer() -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied_0 = flow.dispatch @a(%cst_a) : (tensor<1xi32>) -> %cst_a
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %tied_0_a = flow.tensor.transfer %tied_0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %tied_1 = flow.dispatch @b(%cst_b) : (tensor<1xi32>) -> %cst_b
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %tied_1_b = flow.tensor.transfer %tied_1 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ util.return %tied_0_a, %tied_1_b : tensor<1xi32>, tensor<1xi32>
+}
+
+// -----
+
+// Tests that implicitly placed consumers use their transfer execution affinity.
+
+// CHECK-LABEL: @transfer_execution_affinity
+util.func private @transfer_execution_affinity() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %dispatch_b = flow.dispatch @dispatch(%cst_b) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %dispatch_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that explicitly placed consumers use their explicit execution affinity.
+
+// CHECK-LABEL: @explicit_execution_affinity
+util.func private @explicit_execution_affinity() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %dispatch_b = flow.dispatch @dispatch(%cst_a) {stream.affinity = #hal.device.promise<@dev_b>} : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %dispatch_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumers of operands with multiple affinities inherit those
+// affinities for execution. This allows policies to determine where to execute
+// based on the resources they may be consuming.
+
+// CHECK-LABEL: @consume_multi_affinities
+util.func private @consume_multi_affinities() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %dispatch_ab = flow.dispatch @dispatch(%cst_a, %cst_b) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %dispatch_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests that globals are placed where they are loaded.
+
+// CHECK: util.global private @consumed_global_a
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+util.global private @consumed_global_a : tensor<1xi32>
+util.func private @consumer_fn() -> tensor<1xi32> {
+ // CHECK: util.global.load @consumed_global_a
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load = util.global.load @consumed_global_a : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load_a = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %load_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that a global loaded from two locations is attributed to both
+// affinities. This allows policies to decide whether to replicate the global.
+
+// CHECK: util.global private @consumed_global_ab
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+util.global private @consumed_global_ab : tensor<1xi32>
+util.func private @consumer_fn_a() -> tensor<1xi32> {
+ // CHECK: util.global.load @consumed_global_ab
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %load = util.global.load @consumed_global_ab : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load_a = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %load_a : tensor<1xi32>
+}
+util.func private @consumer_fn_b() -> tensor<1xi32> {
+ // CHECK: util.global.load @consumed_global_ab
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %load = util.global.load @consumed_global_ab : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %load_b = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %load_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops track through global loads.
+
+// CHECK: util.global private mutable @global_b
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+util.global private mutable @global_b : tensor<1xi32>
+util.func private @producer_fn() {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: util.global.store
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.global.store %cst_a, @global_b : tensor<1xi32>
+ util.return
+}
+util.func private @consumer_fn() -> tensor<1xi32> {
+ // CHECK: util.global.load
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %load = util.global.load @global_b : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %load_b = flow.tensor.transfer %load : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %load_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests that globals that are only stored take the fallback placement of
+// their producer. This is silly but can arise prior to global optimization
+// passes that may elide them.
+
+// CHECK: util.global private mutable @global_a
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+util.global private mutable @global_a : tensor<1xi32>
+util.func private @producer_fn() {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: util.global.store
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.global.store %cst_a, @global_a : tensor<1xi32>
+ util.return
+}
+
+// -----
+
+// Tests that global consumers that take on consumed affinity track the global.
+
+// CHECK: util.global private @global_a
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+util.global private @global_a {stream.affinity = #hal.device.promise<@dev_a>} : tensor<1xi32>
+// CHECK: util.global private @global_b
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+util.global private @global_b {stream.affinity = #hal.device.promise<@dev_b>} : tensor<1xi32>
+util.func private @consumer_fn() -> tensor<1xi32> {
+ // CHECK: util.global.load @global_a
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load_a = util.global.load @global_a : tensor<1xi32>
+ // CHECK: util.global.load @global_b
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %load_b = util.global.load @global_b : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %result_ab = flow.dispatch @dispatch(%load_a, %load_b) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %result_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests a global update tick that operates on the global from multiple
+// affinities.
+
+// CHECK: util.global private mutable @global_a
+// CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+util.global private mutable @global_a {stream.affinity = #hal.device.promise<@dev_a>} = dense<123> : tensor<1xi32>
+util.func private @step(%arg0: tensor<2xi32>) -> tensor<2xi32> {
+ // CHECK: util.global.load @global_a
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %load_a = util.global.load @global_a : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %arg0_b = flow.tensor.transfer %arg0 : tensor<2xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]]
+ %result_b:2 = flow.dispatch @dispatch(%load_a, %arg0_b) {stream.affinity = #hal.device.promise<@dev_b>} : (tensor<1xi32>, tensor<2xi32>) -> (tensor<1xi32>, tensor<2xi32>)
+ // CHECK: util.global.store
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.global.store %result_b#0, @global_a : tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %result_b#1 : tensor<2xi32>
+}
+
+// -----
+
+// Tests that constants passed through selects are placed on the consumer.
+
+// CHECK-LABEL: @select_constants_consumed
+util.func private @select_constants_consumed(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_123 = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_456 = flow.tensor.constant dense<456> : tensor<1xi32>
+ // CHECK: arith.select
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = arith.select %cond, %cst_123, %cst_456 : tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that placed operands passed through selects are tracked on consumers.
+
+// CHECK-LABEL: @select_constants_placed
+util.func private @select_constants_placed(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32>
+ // CHECK: arith.select
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst_ab = arith.select %cond, %cst_a, %cst_b : tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %cst_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests that a callee that does not touch an argument still tracks the
+// affinity through it.
+
+// CHECK-LABEL: @passthrough_caller
+util.func private @passthrough_caller() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: util.call @passthrough_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %result_a = util.call @passthrough_callee(%cst_a) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %result_a : tensor<1xi32>
+}
+// CHECK: util.func private @passthrough_callee
+util.func private @passthrough_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %arg0 : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed arguments passed to callees get placed based on
+// how the callee uses them.
+
+// CHECK-LABEL: @consumer_placement_caller
+util.func private @consumer_placement_caller() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: util.call @consumer_placement_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %result_a = util.call @consumer_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %result_a : tensor<1xi32>
+}
+// CHECK: util.func private @consumer_placement_callee
+util.func private @consumer_placement_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %arg0_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that multiple potential affinities are propagated across call edges.
+
+// CHECK-LABEL: @select_caller
+util.func private @select_caller(%arg0: tensor<1xi32>, %cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.call @select_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %result_ab = util.call @select_callee(%arg0_a, %cond) : (tensor<1xi32>, i1) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %result_ab : tensor<1xi32>
+}
+// CHECK: util.func private @select_callee
+util.func private @select_callee(%arg0_a: tensor<1xi32>, %cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<123> : tensor<1xi32>
+ // CHECK: arith.select
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %select_ab = arith.select %cond, %arg0_a, %cst_b : tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %select_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops are propagated across call edges.
+
+// CHECK-LABEL: @consumer_multi_placement_caller
+util.func private @consumer_multi_placement_caller() -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_c>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: util.call @consumer_multi_placement_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %result_0_c = util.call @consumer_multi_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %result_0_a = flow.tensor.transfer %result_0_c : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.call @consumer_multi_placement_callee
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %result_1_c = util.call @consumer_multi_placement_callee(%cst) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %result_1_b = flow.tensor.transfer %result_1_c : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ util.return %result_0_a, %result_1_b : tensor<1xi32>, tensor<1xi32>
+}
+// CHECK: util.func private @consumer_multi_placement_callee
+util.func private @consumer_multi_placement_callee(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %arg0_c = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %arg0_c : tensor<1xi32>
+}
+
+// -----
+
+// Tests that operand/result affinities are tracked across call edges.
+
+// CHECK-LABEL: @dispatch_fn_a
+util.func private @dispatch_fn_a() -> tensor<4xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = flow.tensor.constant dense<123> : tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %1 = flow.tensor.transfer %0 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch_a_0
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %2 = flow.dispatch @dispatch_a_0(%1) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.call @dispatch_fn_b
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %3 = util.call @dispatch_fn_b(%2) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %4 = flow.tensor.transfer %3 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch_a_1
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %5 = flow.dispatch @dispatch_a_1(%4) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %5 : tensor<4xi32>
+}
+// CHECK: util.func private @dispatch_fn_b
+util.func private @dispatch_fn_b(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch_b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %1 = flow.dispatch @dispatch_b(%0) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %1 : tensor<4xi32>
+}
+
+// -----
+
+// Tests a realistic call graph with explicit transfers.
+
+// CHECK-LABEL: @dispatch_fn_a
+util.func private @dispatch_fn_a() -> tensor<4xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = flow.tensor.constant dense<123> : tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %1 = flow.tensor.transfer %0 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.call @dispatch_fn_b
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %2 = util.call @dispatch_fn_b(%1) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.call @dispatch_fn_c
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %3 = util.call @dispatch_fn_c(%1) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %4 = flow.tensor.transfer %2 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %5 = flow.tensor.transfer %3 : tensor<4xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch_a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %6 = flow.dispatch @dispatch_a(%4, %5) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %5 : tensor<4xi32>
+}
+// CHECK: util.func private @dispatch_fn_b
+util.func private @dispatch_fn_b(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch_b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %1 = flow.dispatch @dispatch_b(%0) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %1 : tensor<4xi32>
+}
+// CHECK: util.func private @dispatch_fn_c
+util.func private @dispatch_fn_c(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %0 = flow.tensor.transfer %arg0 : tensor<4xi32> to #hal.device.promise<@dev_c>
+ // CHECK: flow.dispatch @dispatch_c
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_c>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %1 = flow.dispatch @dispatch_c(%0) : (tensor<4xi32>) -> tensor<4xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %1 : tensor<4xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops are tracked across branch edges.
+
+// CHECK-LABEL: @cfg_branch_constant_consumed
+util.func private @cfg_branch_constant_consumed() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: cf.br ^bb1
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ cf.br ^bb1(%cst : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that producer-placed ops are tracked across branch edges.
+
+// CHECK-LABEL: @cfg_branch_dispatch_produced
+util.func private @cfg_branch_dispatch_produced() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: cf.br ^bb1
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ cf.br ^bb1(%cst_a : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %dispatch_a = flow.dispatch @dispatch(%bb1_arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %dispatch_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that back edges on loops track affinity changes.
+
+// CHECK-LABEL: @cfg_loop_back_edge
+util.func private @cfg_loop_back_edge() -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: cf.br ^bb1
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ cf.br ^bb1(%cst_a : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %bb1_arg0_b = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.call @step
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ %cond = util.call @step(%bb1_arg0_b) : (tensor<1xi32>) -> i1
+ // CHECK: cf.cond_br
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ cf.cond_br %cond, ^bb1(%bb1_arg0 : tensor<1xi32>), ^bb2(%bb1_arg0_b : tensor<1xi32>)
+^bb2(%bb2_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %bb2_arg0_c = flow.tensor.transfer %bb2_arg0 : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %bb2_arg0_c : tensor<1xi32>
+}
+util.func private @step(tensor<1xi32>) -> i1
+
+// -----
+
+// Tests that conditional branches acting as selects propagate both affinities.
+
+// CHECK-LABEL: @cfg_cond_branch_select
+util.func private @cfg_cond_branch_select(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_b>} dense<456> : tensor<1xi32>
+ // CHECK: cf.cond_br
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_b>]]
+ cf.cond_br %cond, ^bb1(%cst_a : tensor<1xi32>), ^bb1(%cst_b : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %bb1_arg0 : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops flowing through conditional branches acting
+// as selects get placed on all targets.
+
+// CHECK-LABEL: @cfg_cond_branch_select_consumer
+util.func private @cfg_cond_branch_select_consumer(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: cf.cond_br
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ cf.cond_br %cond, ^bb1(%cst : tensor<1xi32>), ^bb2(%cst : tensor<1xi32>)
+^bb1(%bb1_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %bb1_arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+^bb2(%bb2_arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %bb2_arg0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %cst_b : tensor<1xi32>
+}
+
+// -----
+
+// Tests scf.if capturing consumer-placed ops tracks the affinity into nested
+// regions.
+
+// CHECK-LABEL: @scf_if_capture_consumer
+util.func private @scf_if_capture_consumer(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.if
+ %cst_ab = scf.if %cond -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %cst_a : tensor<1xi32>
+ // CHECK: else
+ } else {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ scf.yield %cst_b : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ util.return %cst_ab : tensor<1xi32>
+}
+
+// -----
+
+// Tests scf.if capturing explicitly placed ops tracks the affinity of their
+// produced results into consumers.
+
+// CHECK-LABEL: @scf_if_capture_producer
+util.func private @scf_if_capture_producer(%cond: i1) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.if
+ %cst_bc = scf.if %cond -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst_b = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ scf.yield %cst_b : tensor<1xi32>
+ // CHECK: else
+ } else {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %cst_c = flow.tensor.transfer %cst_a : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ scf.yield %cst_c : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_b>, #hal.device.promise<@dev_c>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>, #hal.device.promise<@dev_c>]]
+ util.return %cst_bc : tensor<1xi32>
+}
+
+// -----
+
+// Tests that scf.if results produced by unassigned consumer-placed ops have
+// their affinity tracked across scf.yield ops and assigned based on the
+// consumer.
+
+// CHECK-LABEL: @scf_if_consumer_yield
+util.func private @scf_if_consumer_yield(%cond: i1) -> tensor<1xi32> {
+ // CHECK: scf.if
+ %cst = scf.if %cond -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_0 = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %cst_0 : tensor<1xi32>
+ // CHECK: else
+ } else {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_1 = flow.tensor.constant dense<456> : tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %cst_1 : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.transfer %cst : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %cst_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops get placed based on their use in the body.
+
+// CHECK-LABEL: @scf_for_consumer_body_transfer
+util.func private @scf_for_consumer_body_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.for
+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %for : tensor<1xi32>
+}
+
+// -----
+
+// Tests that scf.for ops with transfers/explicit affinities on the loop edges
+// get their affinities tracked across the loop boundary.
+
+// CHECK-LABEL: @scf_for_boundary_transfer
+util.func private @scf_for_boundary_transfer() -> (tensor<1xi32>, tensor<1xi32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.for
+ %for:2 = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst, %arg1 = %cst) -> (tensor<1xi32>, tensor<1xi32>) {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ scf.yield %t, %arg1 : tensor<1xi32>, tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+    // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %for_0_b = flow.tensor.transfer %for#0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %for_1_b = flow.tensor.transfer %for#1 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]]
+ util.return %for_0_b, %for_1_b : tensor<1xi32>, tensor<1xi32>
+}
+
+// -----
+
+// Tests that transfers track through iter_args.
+
+// CHECK-LABEL: @scf_for_body_transfer
+util.func private @scf_for_body_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.for
+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst_a) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %arg0_b = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %t = flow.dispatch @dispatch(%arg0_b) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %for_c = flow.tensor.transfer %for : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %for_c : tensor<1xi32>
+}
+
+// -----
+
+// Tests that placed values track through iter_args to consumers in scf.for
+// bodies.
+
+// CHECK-LABEL: @scf_for_capture_producer
+util.func private @scf_for_capture_producer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.for
+ %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg0 = %cst_a) -> tensor<1xi32> {
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK{LITERAL}: } {stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %for : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops get placed based on their use in the body.
+
+// CHECK-LABEL: @scf_while_consumer_body_transfer
+util.func private @scf_while_consumer_body_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0_a) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %while : tensor<1xi32>
+}
+
+// -----
+
+// Tests that consumer-placed ops get placed based on their use as the result
+// of an scf.while body.
+
+// CHECK-LABEL: @scf_while_consumer_result_transfer
+util.func private @scf_while_consumer_result_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %while_a = flow.tensor.transfer %while : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %while_a : tensor<1xi32>
+}
+
+// -----
+
+// Tests that transfers track through scf.while bodies.
+
+// CHECK-LABEL: @scf_while_body_transfer
+util.func private @scf_while_body_transfer() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %arg0_b = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %t = flow.dispatch @dispatch(%arg0_b) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ }
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>, #hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_c>]]
+ %while_c = flow.tensor.transfer %while : tensor<1xi32> to #hal.device.promise<@dev_c>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_c>]]
+ util.return %while_c : tensor<1xi32>
+}
+
+// -----
+
+// Tests that placed values track through to consumers in scf.while conditions.
+
+// CHECK-LABEL: @scf_while_capture_producer_condition
+util.func private @scf_while_capture_producer_condition() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %arg0_a = flow.tensor.transfer %arg0 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %cond_i32 = flow.tensor.load %arg0_a[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %while : tensor<1xi32>
+}
+
+// -----
+
+// Tests that placed values track through to consumers in scf.while bodies.
+
+// CHECK-LABEL: @scf_while_capture_producer_body
+util.func private @scf_while_capture_producer_body() -> tensor<1xi32> {
+ %c0 = arith.constant 0 : index
+ %c2_i32 = arith.constant 2 : i32
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst_a = flow.tensor.constant {stream.affinity = #hal.device.promise<@dev_a>} dense<123> : tensor<1xi32>
+ // CHECK: scf.while
+ %while = scf.while(%arg0 = %cst_a) : (tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.load
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %cond_i32 = flow.tensor.load %arg0[%c0] : tensor<1xi32>
+ %cond = arith.cmpi slt, %cond_i32, %c2_i32 : i32
+ // CHECK: scf.condition
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.condition(%cond) %arg0 : tensor<1xi32>
+ } do {
+ ^bb0(%arg0: tensor<1xi32>):
+ // CHECK: flow.dispatch @dispatch
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %t = flow.dispatch @dispatch(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.yield
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ scf.yield %t : tensor<1xi32>
+ // CHECK: } attributes {
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ }
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %while : tensor<1xi32>
+}
+
+// -----
+
+// Tests a realistic program with ABI ops.
+
+// CHECK-LABEL: @simple_program
+util.func public @simple_program(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view {
+ // CHECK: hal.tensor.import
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = hal.tensor.import on(#hal.device.promise<@dev_a>) wait(%arg1) => %arg0 "input0" : !hal.buffer_view -> tensor<1xi32>
+ // CHECK: util.call @_simple_program
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %1 = util.call @_simple_program(%0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %2 = flow.tensor.transfer %1 : tensor<1xi32> to #hal.device.promise<@dev_a>
+ // CHECK: hal.tensor.barrier
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %3 = hal.tensor.barrier join(%2 : tensor<1xi32>) => %arg2 : !hal.fence
+ // CHECK: hal.tensor.export
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ %4 = hal.tensor.export on(#hal.device.promise<@dev_a>) %3 "output0" : tensor<1xi32> -> !hal.buffer_view
+ util.return %4 : !hal.buffer_view
+}
+// CHECK: util.func private @_simple_program
+util.func private @_simple_program(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: util.call @dispatch_a
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = util.call @dispatch_a(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: flow.tensor.transfer
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %1 = flow.tensor.transfer %0 : tensor<1xi32> to #hal.device.promise<@dev_b>
+ // CHECK: util.call @dispatch_b
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %2 = util.call @dispatch_b(%1) : (tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %2 : tensor<1xi32>
+}
+// CHECK: util.func private @dispatch_a
+util.func private @dispatch_a(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %cst = flow.tensor.constant dense<[1]> : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch_a
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_a>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>], [#hal.device.promise<@dev_a>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_a>]]
+ %0 = flow.dispatch @dispatch_a(%arg0, %cst) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_a>]]
+ util.return %0 : tensor<1xi32>
+}
+// CHECK: util.func private @dispatch_b
+util.func private @dispatch_b(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+ // CHECK: flow.tensor.constant
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %cst = flow.tensor.constant dense<[2]> : tensor<1xi32>
+ // CHECK: flow.dispatch @dispatch_b
+ // CHECK-SAME{LITERAL}: stream.affinities = [#hal.device.promise<@dev_b>]
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>], [#hal.device.promise<@dev_b>]]
+ // CHECK-SAME{LITERAL}: stream.affinities.results = [[#hal.device.promise<@dev_b>]]
+ %0 = flow.dispatch @dispatch_b(%arg0, %cst) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: util.return
+ // CHECK-SAME{LITERAL}: stream.affinities.operands = [[#hal.device.promise<@dev_b>]]
+ util.return %0 : tensor<1xi32>
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
index e014588..bb46c83 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
@@ -679,7 +679,8 @@
// traversal algorithm separated from the policy here. This would let us
// reuse the traversal for other kinds of walks that are more specific (like
// only getting the ops or values instead of both, etc).
-TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn) {
+TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn,
+ TraversalBehavior options) {
// Fast-path short-circuit for constants, which are like 25% of all IR.
if (value.getDefiningOp() &&
value.getDefiningOp()->hasTrait<OpTrait::ConstantLike>()) {
@@ -856,15 +857,17 @@
// If the op is tied we may need to walk up to the operand the result is
// tied to.
- if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
- auto tiedOperand = tiedOp.getTiedResultOperand(resultValue);
- if (tiedOperand) {
- LLVM_DEBUG({
- llvm::dbgs() << " + queuing tied operand ";
- tiedOperand.printAsOperand(llvm::dbgs(), asmState);
- llvm::dbgs() << "\n";
- });
- worklist.insert(tiedOperand);
+ if (!bitEnumContains(options, TraversalBehavior::DONT_WALK_TIED_VALUES)) {
+ if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(definingOp)) {
+ auto tiedOperand = tiedOp.getTiedResultOperand(resultValue);
+ if (tiedOperand) {
+ LLVM_DEBUG({
+ llvm::dbgs() << " + queuing tied operand ";
+ tiedOperand.printAsOperand(llvm::dbgs(), asmState);
+ llvm::dbgs() << "\n";
+ });
+ worklist.insert(tiedOperand);
+ }
}
}
@@ -891,7 +894,8 @@
return result;
}
-TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn) {
+TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn,
+ TraversalBehavior options) {
LLVM_DEBUG(llvm::dbgs() << "[[ Explorer::walkTransitiveUses ]]\n");
TraversalResult result = TraversalResult::COMPLETE;
@@ -1090,15 +1094,17 @@
// If the op is tied we may need to walk down to the results the operand
// is tied to (multiple results can tie the same operand).
- if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(ownerOp)) {
- for (auto tiedResult :
- tiedOp.getOperandTiedResults(use.getOperandNumber())) {
- LLVM_DEBUG({
- llvm::dbgs() << " + queuing tied result ";
- tiedResult.printAsOperand(llvm::dbgs(), asmState);
- llvm::dbgs() << "\n";
- });
- worklist.insert(tiedResult);
+ if (!bitEnumContains(options, TraversalBehavior::DONT_WALK_TIED_VALUES)) {
+ if (auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(ownerOp)) {
+ for (auto tiedResult :
+ tiedOp.getOperandTiedResults(use.getOperandNumber())) {
+ LLVM_DEBUG({
+ llvm::dbgs() << " + queuing tied result ";
+ tiedResult.printAsOperand(llvm::dbgs(), asmState);
+ llvm::dbgs() << "\n";
+ });
+ worklist.insert(tiedResult);
+ }
}
}
@@ -1149,14 +1155,18 @@
return result;
}
-TraversalResult Explorer::walkTransitiveUsers(Value value, OperationWalkFn fn) {
+TraversalResult Explorer::walkTransitiveUsers(Value value, OperationWalkFn fn,
+ TraversalBehavior options) {
DenseSet<Operation *> visitedOwners;
- return walkTransitiveUses(value, [&](OpOperand &use) {
- if (visitedOwners.insert(use.getOwner()).second) {
- return fn(use.getOwner());
- }
- return WalkResult::advance();
- });
+ return walkTransitiveUses(
+ value,
+ [&](OpOperand &use) {
+ if (visitedOwners.insert(use.getOwner()).second) {
+ return fn(use.getOwner());
+ }
+ return WalkResult::advance();
+ },
+ options);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
index 1e975be..35ee12a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
@@ -37,6 +37,31 @@
IGNORE,
};
+enum class TraversalBehavior : uint32_t {
+ // When traversing defining ops any tied result will move through its tied
+ // operand. When traversing uses any tied operand will move through its tied
+ // results (as many as are tied to the operand).
+ DEFAULT = 0u,
+ // Don't traverse through tied operands or results.
+ DONT_WALK_TIED_VALUES = 1 << 0u,
+};
+inline TraversalBehavior operator~(TraversalBehavior value) {
+ return static_cast<TraversalBehavior>(~static_cast<uint32_t>(value));
+}
+inline TraversalBehavior operator|(TraversalBehavior lhs,
+ TraversalBehavior rhs) {
+ return static_cast<TraversalBehavior>(static_cast<uint32_t>(lhs) |
+ static_cast<uint32_t>(rhs));
+}
+inline TraversalBehavior operator&(TraversalBehavior lhs,
+ TraversalBehavior rhs) {
+ return static_cast<TraversalBehavior>(static_cast<uint32_t>(lhs) &
+ static_cast<uint32_t>(rhs));
+}
+inline bool bitEnumContains(TraversalBehavior bits, TraversalBehavior bit) {
+ return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
+}
+
// Boolean operations on TraversalResult behave as though `INCOMPLETE` is
// truthy to allow for |='ing results.
enum class TraversalResult {
@@ -313,7 +338,9 @@
// Walk %2: [%2 of producer.b]
// Walk @some_user::%arg0: [%0 of producer.a]
// Walk @some_user::ret0: [%2 of producer.b]
- TraversalResult walkDefiningOps(Value value, ResultWalkFn fn);
+ TraversalResult
+ walkDefiningOps(Value value, ResultWalkFn fn,
+ TraversalBehavior options = TraversalBehavior::DEFAULT);
// Randomly walks uses of |value| and any transitive alias of |value|.
// The uses may come from any part of the program.
@@ -334,13 +361,17 @@
// Walk %arg0: [%arg0 of producer.a]
// Walk %0: [%0 of call @some_user, %arg0 of producer.b]
// Walk %2: [%2 of return, %1 of return]
- TraversalResult walkTransitiveUses(Value value, UseWalkFn fn);
+ TraversalResult
+ walkTransitiveUses(Value value, UseWalkFn fn,
+ TraversalBehavior options = TraversalBehavior::DEFAULT);
// Randomly walks uses of |value| and any transitive alias of |value| and
// returns each owner operation once. As a value may be used multiple times
// by a single operation this is equivalent to a walkTransitiveUses with
// deduplication on the owner of the use.
- TraversalResult walkTransitiveUsers(Value value, OperationWalkFn fn);
+ TraversalResult
+ walkTransitiveUsers(Value value, OperationWalkFn fn,
+ TraversalBehavior options = TraversalBehavior::DEFAULT);
private:
// Maps callee callable region -> call sites.
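The new `options` parameter defaults to `TraversalBehavior::DEFAULT`, so existing call sites compile unchanged. A minimal caller-side sketch of opting out of tied-value traversal, assuming an already-initialized `Explorer` named `explorer` and a `Value` named `value` (both placeholders, not part of this change):

  // Illustrative only: walk transitive uses of |value| without hopping
  // across tied operands/results so each SSA value is treated distinctly.
  TraversalResult result = explorer.walkTransitiveUses(
      value,
      [&](OpOperand &use) {
        // Inspect each transitive use here; advance to keep walking.
        return WalkResult::advance();
      },
      TraversalBehavior::DONT_WALK_TIED_VALUES);
  if (result == TraversalResult::INCOMPLETE) {
    // Some uses could not be enumerated (e.g. indirect calls); callers
    // should treat the result conservatively.
  }
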
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
index b82d599..ab1adf0 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
@@ -18,6 +18,35 @@
namespace {
+template <typename OpT>
+struct OptionalOpAffinityAttrExternalModel
+ : public IREE::Stream::AffinityOpInterface::ExternalModel<
+ OptionalOpAffinityAttrExternalModel<OpT>, OpT> {
+ static void add(MLIRContext *context) {
+ OpT::template attachInterface<OptionalOpAffinityAttrExternalModel<OpT>>(
+ *context);
+ }
+
+ // Affinity only required for results that hold resources that
+ // require placement.
+ bool requiresAffinity(Operation *op) const {
+ auto resultType = cast<OpT>(op).getResult().getType();
+ return isa<TensorType>(resultType);
+ }
+
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
+ return op->getAttrOfType<IREE::Stream::AffinityAttr>("stream.affinity");
+ }
+
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value) {
+ op->setAttr("stream.affinity", value);
+ } else {
+ op->removeAttr("stream.affinity");
+ }
+ }
+};
+
struct FlowTransferTargetAffinityAttrExternalModel
: public IREE::Stream::AffinityOpInterface::ExternalModel<
FlowTransferTargetAffinityAttrExternalModel,
@@ -29,11 +58,11 @@
bool requiresAffinity(Operation *op) const { return true; }
- IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
return op->getAttrOfType<IREE::Stream::AffinityAttr>("target");
}
- void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
op->setAttr("target", value);
}
};
@@ -49,12 +78,14 @@
bool requiresAffinity(Operation *op) const { return false; }
- IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
+ bool pinsValueAffinity(Operation *op) const { return true; }
+
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
return op->getAttrOfType<IREE::Stream::AffinityAttr>("affinity");
}
- void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
- if (value)
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value) {
op->setAttr("affinity", value);
} else {
op->removeAttr("affinity");
@@ -78,12 +109,12 @@
return isa<TensorType>(globalType);
}
- IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
return op->getAttrOfType<IREE::Stream::AffinityAttr>("stream.affinity");
}
- void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
- if (value)
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value) {
op->setAttr("stream.affinity", value);
} else {
op->removeAttr("stream.affinity");
@@ -91,7 +122,7 @@
}
};
-template <typename OpT>
+template <typename OpT, bool kRequiresAffinity = true>
struct AffinityOpAttrExternalModel
: public IREE::Stream::AffinityOpInterface::ExternalModel<
AffinityOpAttrExternalModel<OpT, kRequiresAffinity>, OpT> {
@@ -102,14 +133,14 @@
// Most structural ops don't require affinities and after placement we don't
// use the affinities even if the ops still exist.
- bool requiresAffinity(Operation *op) const { return false; }
+ bool requiresAffinity(Operation *op) const { return kRequiresAffinity; }
- IREE::Stream::AffinityAttr getAffinity(Operation *op) const {
+ IREE::Stream::AffinityAttr getAffinityAttr(Operation *op) const {
return op->getAttrOfType<IREE::Stream::AffinityAttr>("stream.affinity");
}
- void setAffinity(Operation *op, IREE::Stream::AffinityAttr value) const {
- if (value)
+ void setAffinityAttr(Operation *op, IREE::Stream::AffinityAttr value) const {
+ if (value) {
op->setAttr("stream.affinity", value);
} else {
op->removeAttr("stream.affinity");
@@ -117,32 +148,71 @@
}
};
+struct TensorAffinityTypeExternalModel
+ : public IREE::Stream::AffinityTypeInterface::ExternalModel<
+ TensorAffinityTypeExternalModel, RankedTensorType> {
+ static void add(MLIRContext *context) {
+ RankedTensorType::attachInterface<TensorAffinityTypeExternalModel>(
+ *context);
+ }
+};
+
} // namespace
void registerStreamExternalModels(DialectRegistry ®istry) {
- registry.insert<IREE::Flow::FlowDialect>();
+ registry.addExtension(+[](MLIRContext *context) {
+ TensorAffinityTypeExternalModel::add(context);
+ });
+
+ registry.insert<arith::ArithDialect>();
registry.addExtension(
- +[](MLIRContext *context, IREE::Flow::FlowDialect *dialect) {
- FlowTransferTargetAffinityAttrExternalModel::add(context);
+ +[](MLIRContext *context, arith::ArithDialect *dialect) {
+ OptionalOpAffinityAttrExternalModel<arith::ConstantOp>::add(context);
});
+ registry.insert<IREE::Flow::FlowDialect>();
+ registry.addExtension(+[](MLIRContext *context,
+ IREE::Flow::FlowDialect *dialect) {
+ FlowTransferTargetAffinityAttrExternalModel::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::DispatchRegionOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::DispatchWorkgroupsOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::DispatchOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::CallOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorConstantOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorDynamicConstantOp>::add(
+ context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorAllocaOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorEmptyOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorSplatOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorCloneOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorSliceOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::TensorUpdateOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::ChannelDefaultOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveAllGatherOp>::add(
+ context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveAllReduceOp>::add(
+ context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveAllToAllOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveReduceScatterOp>::add(
+ context);
+ AffinityOpAttrExternalModel<IREE::Flow::CollectiveSendRecvOp>::add(context);
+ });
+
registry.insert<IREE::HAL::HALDialect>();
registry.addExtension(+[](MLIRContext *context,
IREE::HAL::HALDialect *dialect) {
HALTensorAffinityAttrExternalModel<IREE::HAL::TensorImportOp>::add(context);
HALTensorAffinityAttrExternalModel<IREE::HAL::TensorExportOp>::add(context);
HALTensorAffinityAttrExternalModel<IREE::HAL::TensorAliasOp>::add(context);
- HALTensorAffinityAttrExternalModel<IREE::HAL::TensorBarrierOp>::add(
- context);
});
registry.insert<IREE::Util::UtilDialect>();
- registry.addExtension(
- +[](MLIRContext *context, IREE::Util::UtilDialect *dialect) {
- GlobalOpAffinityAttrExternalModel<IREE::Util::GlobalOp>::add(context);
- AffinityOpAttrExternalModel<IREE::Util::InitializerOp>::add(context);
- AffinityOpAttrExternalModel<IREE::Util::FuncOp>::add(context);
- });
+ registry.addExtension(+[](MLIRContext *context,
+ IREE::Util::UtilDialect *dialect) {
+ GlobalOpAffinityAttrExternalModel<IREE::Util::GlobalOp>::add(context);
+ AffinityOpAttrExternalModel<IREE::Util::InitializerOp, false>::add(context);
+ AffinityOpAttrExternalModel<IREE::Util::FuncOp, false>::add(context);
+ });
}
} // namespace mlir::iree_compiler
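
For illustration only (not part of the patch): a minimal sketch of how downstream code might consult the interface these external models attach. It assumes only the `AffinityOpInterface` methods touched above (`requiresAffinity`/`getAffinityAttr`) plus the Stream headers already included in this file; the helper name `findExplicitAffinity` is hypothetical.

  #include "iree/compiler/Dialect/Stream/IR/StreamOps.h"  // pulls in the Stream op interfaces

  // Hypothetical helper: returns the affinity explicitly pinned on |op|, if any.
  // Relies only on the renamed getAffinityAttr() accessor registered above.
  static IREE::Stream::AffinityAttr findExplicitAffinity(Operation *op) {
    if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
      // Ops registered with kRequiresAffinity=false (util.func/util.initializer)
      // may legally return null here; callers must handle the unassigned case.
      return affinityOp.getAffinityAttr();
    }
    return {};
  }
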
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
index 81e204e..13bc6bf 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel
@@ -57,6 +57,7 @@
"//compiler/src/iree/compiler/Dialect/HAL/Analysis",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/LinalgExt/IR",
+ "//compiler/src/iree/compiler/Dialect/Stream/Analysis",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
index 723dacc..3764d49 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt
@@ -68,6 +68,7 @@
iree::compiler::Dialect::HAL::Analysis
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
+ iree::compiler::Dialect::Stream::Analysis
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::IR
PUBLIC
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
index a5e1a86..ba415b3 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
+++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "iree/compiler/Dialect/Stream/Analysis/Affinity.h"
#include "iree/compiler/Preprocessing/Common/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -25,7 +26,7 @@
namespace mlir::iree_compiler::Preprocessing {
-#define GEN_PASS_DEF_PADTOINTRINSICS
+#define GEN_PASS_DEF_PADTOINTRINSICSPASS
#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export
namespace {
@@ -533,7 +534,7 @@
}
struct PadToIntrinsicsPass
- : public impl::PadToIntrinsicsBase<PadToIntrinsicsPass> {
+ : public impl::PadToIntrinsicsPassBase<PadToIntrinsicsPass> {
using Base::Base;
void runOnOperation() override;
};
@@ -544,10 +545,15 @@
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
- auto funcOp = getOperation();
- IREE::HAL::DeviceAnalysis deviceAnalysis(funcOp->getParentOp());
- if (failed(deviceAnalysis.run()))
+ auto moduleOp = getOperation();
+ IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
+ if (failed(affinityAnalysis.run())) {
return signalPassFailure();
+ }
+ IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
+ if (failed(deviceAnalysis.run())) {
+ return signalPassFailure();
+ }
bool padConvOps = padTargetType == PadTargetType::ConvOp ||
padTargetType == PadTargetType::All;
@@ -555,37 +561,46 @@
padTargetType == PadTargetType::All;
SmallVector<linalg::LinalgOp> targetConvOps;
SmallVector<linalg::LinalgOp> targetContractOps;
- funcOp.walk([&](linalg::LinalgOp linalgOp) {
- if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp.getOperation()) && padConvOps) {
- // Add convOps into worklist.
- targetConvOps.push_back(linalgOp);
- } else if (isa<linalg::BatchMatmulOp, linalg::MatmulOp,
- linalg::MatmulTransposeBOp>(linalgOp.getOperation()) &&
- padContractionOps) {
- // Add named contractionOps into worklist.
- targetContractOps.push_back(linalgOp);
- } else if (isa<linalg::GenericOp>(linalgOp.getOperation()) &&
- linalg::isaContractionOpInterface(linalgOp) &&
- padContractionOps) {
- // Add named generic contractionOps into worklist.
- targetContractOps.push_back(linalgOp);
- }
- });
+ for (auto funcOp : moduleOp.getOps<FunctionOpInterface>()) {
+ funcOp.walk([&](linalg::LinalgOp linalgOp) {
+ if (isa<linalg::Conv2DNhwcHwcfOp>(linalgOp.getOperation()) &&
+ padConvOps) {
+ targetConvOps.push_back(linalgOp);
+ } else if (isa<linalg::BatchMatmulOp, linalg::MatmulOp,
+ linalg::MatmulTransposeBOp>(linalgOp.getOperation()) &&
+ padContractionOps) {
+ targetContractOps.push_back(linalgOp);
+ } else if (isa<linalg::GenericOp>(linalgOp.getOperation()) &&
+ linalg::isaContractionOpInterface(linalgOp) &&
+ padContractionOps) {
+ targetContractOps.push_back(linalgOp);
+ }
+ });
+ }
// Iterate through and pad ops in the worklists.
+ auto getRequiredExecutableTargetAttrs = [&](Operation *op) {
+ SetVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
+ SmallVector<IREE::Stream::AffinityAttr> affinityAttrs;
+ if (affinityAnalysis.tryInferExecutionAffinity(op, affinityAttrs)) {
+ for (auto affinityAttr : affinityAttrs) {
+ deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
+ executableTargetAttrs);
+ }
+ }
+ return executableTargetAttrs;
+ };
IRRewriter rewriter(context);
for (auto convOp : targetConvOps) {
rewriter.setInsertionPoint(convOp);
- SetVector<IREE::HAL::ExecutableTargetAttr> executableTargets;
- deviceAnalysis.gatherRequiredExecutableTargets(convOp, executableTargets);
- padConvOp(rewriter, convOp, executableTargets.getArrayRef());
+ auto executableTargetAttrs = getRequiredExecutableTargetAttrs(convOp);
+ padConvOp(rewriter, convOp, executableTargetAttrs.getArrayRef());
}
for (auto contractOp : targetContractOps) {
rewriter.setInsertionPoint(contractOp);
- SetVector<IREE::HAL::ExecutableTargetAttr> executableTargets;
- deviceAnalysis.gatherRequiredExecutableTargets(contractOp,
- executableTargets);
- padContractionLikeOp(rewriter, contractOp, executableTargets.getArrayRef());
+ auto executableTargetAttrs = getRequiredExecutableTargetAttrs(contractOp);
+ padContractionLikeOp(rewriter, contractOp,
+ executableTargetAttrs.getArrayRef());
}
}
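
As an aside, the affinity-then-device lookup above generalizes to other module-level preprocessing passes. A condensed sketch of the pattern follows, purely for reference: every name comes from the code in this diff, and it assumes it runs inside a module pass's runOnOperation() with `op` being some operation of interest.

  // Map an op to the executable targets it may be compiled for:
  // 1) infer its execution affinities via whole-program affinity analysis,
  // 2) gather the executable targets of the devices each affinity selects.
  IREE::Stream::AffinityAnalysis affinityAnalysis(moduleOp);
  IREE::HAL::DeviceAnalysis deviceAnalysis(moduleOp);
  if (failed(affinityAnalysis.run()) || failed(deviceAnalysis.run()))
    return signalPassFailure();
  SetVector<IREE::HAL::ExecutableTargetAttr> executableTargetAttrs;
  SmallVector<IREE::Stream::AffinityAttr> affinityAttrs;
  if (affinityAnalysis.tryInferExecutionAffinity(op, affinityAttrs)) {
    for (auto affinityAttr : affinityAttrs) {
      deviceAnalysis.gatherRequiredExecutableTargets(affinityAttr, op,
                                                     executableTargetAttrs);
    }
  }
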
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
index edc1705..ca29a52 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
+++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td
@@ -84,8 +84,8 @@
];
}
-def PadToIntrinsics :
- InterfacePass<"iree-preprocessing-pad-to-intrinsics", "mlir::FunctionOpInterface"> {
+def PadToIntrinsicsPass :
+ Pass<"iree-preprocessing-pad-to-intrinsics", "ModuleOp"> {
let summary = "Pad linalg ops such that we can use target's intrinsics.";
let dependentDialects = [
"mlir::linalg::LinalgDialect",
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
index 5761741..7d9da45 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_mfma.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics,canonicalize))" | FileCheck %s
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},canonicalize))" | FileCheck %s -check-prefix=CONTRACT
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics,func.func(canonicalize))" | FileCheck %s
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},func.func(canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},func.func(canonicalize))" | FileCheck %s -check-prefix=CONTRACT
// CHECK-LABEL: func.func @main0(
diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
index f7f8328..ece0283 100644
--- a/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
+++ b/compiler/src/iree/compiler/Preprocessing/Common/test/pad_to_intrinsics_wmma.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics,canonicalize))" | FileCheck %s
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION
-// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(func.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},canonicalize))" | FileCheck %s -check-prefix=CONTRACT
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics,func.func(canonicalize))" | FileCheck %s
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv},func.func(canonicalize))" | FileCheck %s -check-prefix=CONVOLUTION
+// RUN: iree-opt --split-input-file %s --iree-gpu-test-target=gfx1100 --pass-pipeline="builtin.module(iree-preprocessing-pad-to-intrinsics{pad-target-type=contraction},func.func(canonicalize))" | FileCheck %s -check-prefix=CONTRACT
// CHECK: func.func @matmul_static(
// CHECK-SAME: %[[ARG0:.+]]: tensor<10x20xf16>,