blob: 8381449ee63b02ea120b27f9adf50bfb7aa1811f [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Stream/Analysis/ResourceUsage.h"
#include <utility>
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Util/Analysis/DFX/Element.h"
#include "iree/compiler/Dialect/Util/Analysis/DFX/Solver.h"
#include "iree/compiler/Dialect/Util/Analysis/DFX/State.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#define DEBUG_TYPE "iree-util-dfx"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Stream {
// Policy switch for usage propagation across stream.async.transfer ops.
// When enabled, ValueResourceUsage does not join usage between a transfer's
// source and result if one of them is external and the other is internal.
// This pushes values that are copied out of external resources toward
// transient (stream-ordered) allocations.
//
// TODO(benvanik): decide whether to favor copying external values into
// transients or to reuse the external values directly. In very loopy programs
// turning this on enables many more stream-ordered allocations, but an
// allocation hoisting pass could achieve the same; until such a pass exists
// this flag allows comparing the two manually. A possible long-term fix is a
// NOT_LOOP_CARRIED usage bit set on values that reach a CFG back edge, so that
// those (rather than straight-line escaping results) are favored as
// transients.
static constexpr bool kFavorTransients{false};
// Starts by assuming that the resource is never used and then removes assumed
// bits based on the usage in the program.
//
// BitIntegerState starts with all bits assumed so we invert the usage bits
// such that each bit indicates that some particular usage is _not_ performed.
// As the solver runs the assumed bits are removed each time the resource is
// used in a particular way (if the resource is read as part of a transfer
// operation then the NOT_TRANSFER_READ assumed bit will be removed). Upon
// completion we'll know for each resource what _is not_ performed and thus by
// inverting the bits we can arrive at what _is_ performed.
//
// Best state: never used at all (never read/written/etc).
// Worst state: used for all kinds of things.
//
// The underlying state is a 12-bit BitIntegerState. Its best value 4095 has
// all twelve NOT_* bits set (checked by the static_assert below) and its
// worst value is 0. |ElementT| is the DFX element kind being wrapped (for
// example DFX::ValueElement).
template <typename ElementT>
class AbstractResourceUsage
: public DFX::StateWrapper<DFX::BitIntegerState<uint16_t, 4095, 0>,
ElementT> {
public:
using BaseType =
DFX::StateWrapper<DFX::BitIntegerState<uint16_t, 4095, 0>, ElementT>;
// Inverted bits matching ResourceUsageBitfield.
// While a NOT_* bit is still assumed, the corresponding usage has not (yet)
// been observed for the resource.
enum {
NOT_INDIRECT = 1u << 0,
NOT_EXTERNAL = 1u << 1,
NOT_MUTATED = 1u << 2, // beyond definition
NOT_CONSTANT = 1u << 3,
NOT_TRANSFER_READ = 1u << 4,
NOT_TRANSFER_WRITE = 1u << 5,
NOT_STAGING_READ = 1u << 6,
NOT_STAGING_WRITE = 1u << 7,
NOT_DISPATCH_READ = 1u << 8,
NOT_DISPATCH_WRITE = 1u << 9,
NOT_GLOBAL_READ = 1u << 10,
NOT_GLOBAL_WRITE = 1u << 11,
BEST_STATE = NOT_INDIRECT | NOT_EXTERNAL | NOT_MUTATED | NOT_CONSTANT |
NOT_TRANSFER_READ | NOT_TRANSFER_WRITE | NOT_STAGING_READ |
NOT_STAGING_WRITE | NOT_DISPATCH_READ | NOT_DISPATCH_WRITE |
NOT_GLOBAL_READ | NOT_GLOBAL_WRITE,
};
// BEST_STATE (the union of all NOT_* bits) must equal the best value of the
// wrapped BitIntegerState.
static_assert(BEST_STATE == BaseType::getBestState(),
"unexpected BEST_STATE value");
// Returns false for usage combinations a single resource cannot have:
// - staging (read or write) together with dispatch (read or write);
// - constant together with mutated.
// |bits| uses the inverted NOT_* encoding. The usages whose flags are
// commented out below are currently unconstrained.
static bool isValidState(uint16_t bits) {
// bool isIndirect = (bits & NOT_INDIRECT) != NOT_INDIRECT;
// bool isExternal = (bits & NOT_EXTERNAL) != NOT_EXTERNAL;
bool isMutated = (bits & NOT_MUTATED) != NOT_MUTATED;
bool isConstant = (bits & NOT_CONSTANT) != NOT_CONSTANT;
// bool isTransferRead = (bits & NOT_TRANSFER_READ) != NOT_TRANSFER_READ;
// bool isTransferWrite = (bits & NOT_TRANSFER_WRITE) != NOT_TRANSFER_WRITE;
bool isStagingRead = (bits & NOT_STAGING_READ) != NOT_STAGING_READ;
bool isStagingWrite = (bits & NOT_STAGING_WRITE) != NOT_STAGING_WRITE;
bool isDispatchRead = (bits & NOT_DISPATCH_READ) != NOT_DISPATCH_READ;
bool isDispatchWrite = (bits & NOT_DISPATCH_WRITE) != NOT_DISPATCH_WRITE;
// bool isGlobalRead = (bits & NOT_GLOBAL_READ) != NOT_GLOBAL_READ;
// bool isGlobalWrite = (bits & NOT_GLOBAL_WRITE) != NOT_GLOBAL_WRITE;
// Cannot be both staging and dispatch.
if ((isStagingRead || isStagingWrite) &&
(isDispatchRead || isDispatchWrite)) {
return false;
}
// Cannot be both constant and mutated.
// TODO(benvanik): maybe allow this for initializers that perform dispatches
// to initialize the resources. This introduces copies of those that are
// annoying to elide later on.
if (isConstant && isMutated) {
return false;
}
return true;
}
// Converts inverted NOT_* |bits| into the positive ResourceUsageBitfield
// encoding. The bits are flipped and then masked to the twelve defined bits,
// which discards the high bits set by the integer-promoted `~`.
ResourceUsageBitfield convertBitsToResourceUsage(uint16_t bits) const {
return static_cast<ResourceUsageBitfield>(~bits & BEST_STATE);
}
// Usage derived from the known bits, i.e. the bits analysis can never remove.
ResourceUsageBitfield getKnownUsage() const {
return convertBitsToResourceUsage(this->getKnown());
}
// Usage derived from the currently assumed bits.
ResourceUsageBitfield getAssumedUsage() const {
return convertBitsToResourceUsage(this->getAssumed());
}
// Renders the assumed usage as a '|'-separated list for debug output.
// NOTE(review): the internal/external and immutable/mutable entries are always
// appended, so `str` is never empty and the "*" fallback is unreachable.
const std::string getAsStr() const override {
std::string str;
auto append = [&](const char *part) {
if (!str.empty()) str += '|';
str += part;
};
if (!this->isAssumed(NOT_INDIRECT)) append("indirect");
append(this->isAssumed(NOT_EXTERNAL) ? "internal" : "external");
append(this->isAssumed(NOT_MUTATED) ? "immutable" : "mutable");
if (!this->isAssumed(NOT_CONSTANT)) append("constant");
if (!this->isAssumed(NOT_TRANSFER_READ)) append("transfer_read");
if (!this->isAssumed(NOT_TRANSFER_WRITE)) append("transfer_write");
if (!this->isAssumed(NOT_STAGING_READ)) append("staging_read");
if (!this->isAssumed(NOT_STAGING_WRITE)) append("staging_write");
if (!this->isAssumed(NOT_DISPATCH_READ)) append("dispatch_read");
if (!this->isAssumed(NOT_DISPATCH_WRITE)) append("dispatch_write");
if (!this->isAssumed(NOT_GLOBAL_READ)) append("global_read");
if (!this->isAssumed(NOT_GLOBAL_WRITE)) append("global_write");
return str.empty() ? "*" : str;
}
protected:
explicit AbstractResourceUsage(const Position &pos) : BaseType(pos) {}
// Add known bits based on the static type information available.
// Doing this sets the worst case bits that analysis cannot remove.
//
// For each lifetime:
// - addKnownBits() pins NOT_* bits for usages that the lifetime rules out.
// - intersectAssumedBits() with a mask clears assumed bits, which
//   pre-assumes the corresponding usages.
//
// NOTE(review): every NOT_* constant is a single distinct bit, so a mask of
// the form `~A | ~B | ...` evaluates to all ones (it equals ~(A & B & ...)),
// and `BEST_STATE & (...)` is just BEST_STATE. The Staging, Transient,
// Variable and Constant intersectAssumedBits() calls are therefore no-ops
// beyond the initial intersectAssumedBits(BEST_STATE). If the intent was to
// pre-assume those usages, the mask would be `~(A | B | ...)`. Confirm the
// intent before changing it, because doing so alters analysis results.
void initializeFromType(IREE::Stream::ResourceType type) {
BaseType::intersectAssumedBits(BEST_STATE);
switch (type.getLifetime()) {
case Lifetime::Unknown:
// No constraints beyond the best state.
break;
case Lifetime::External:
// Pre-assumed external; can never be constant or used for staging.
BaseType::intersectAssumedBits(BEST_STATE & ~NOT_EXTERNAL);
BaseType::addKnownBits(NOT_CONSTANT | NOT_STAGING_READ |
NOT_STAGING_WRITE);
break;
case Lifetime::Staging:
// Never external, constant, dispatched, or read/written via globals.
// (The intersect is a no-op; see NOTE above.)
BaseType::intersectAssumedBits(
BEST_STATE & (~NOT_STAGING_READ | ~NOT_STAGING_WRITE |
~NOT_TRANSFER_READ | ~NOT_TRANSFER_WRITE));
BaseType::addKnownBits(NOT_EXTERNAL | NOT_CONSTANT | NOT_DISPATCH_READ |
NOT_DISPATCH_WRITE | NOT_GLOBAL_READ |
NOT_GLOBAL_WRITE);
break;
case Lifetime::Transient:
// Never external, constant, or used for staging.
// (The intersect is a no-op; see NOTE above.)
BaseType::intersectAssumedBits(
BEST_STATE & (~NOT_DISPATCH_READ | ~NOT_DISPATCH_WRITE |
~NOT_TRANSFER_READ | ~NOT_TRANSFER_WRITE));
BaseType::addKnownBits(NOT_EXTERNAL | NOT_CONSTANT | NOT_STAGING_READ |
NOT_STAGING_WRITE);
break;
case Lifetime::Variable:
// Never external, constant, or used for staging.
// (The intersect is a no-op; see NOTE above.)
BaseType::intersectAssumedBits(
BEST_STATE & (~NOT_GLOBAL_READ | ~NOT_GLOBAL_WRITE |
~NOT_TRANSFER_READ | ~NOT_TRANSFER_WRITE));
BaseType::addKnownBits(NOT_EXTERNAL | NOT_CONSTANT | NOT_STAGING_READ |
NOT_STAGING_WRITE);
break;
case Lifetime::Constant:
// Never mutated, external, or used for staging.
// (The intersect is a no-op; see NOTE above.)
BaseType::intersectAssumedBits(
BEST_STATE &
(~NOT_CONSTANT | ~NOT_TRANSFER_READ | ~NOT_TRANSFER_WRITE));
BaseType::addKnownBits(NOT_MUTATED | NOT_EXTERNAL | NOT_STAGING_READ |
NOT_STAGING_WRITE);
break;
}
}
};
// Starts with the best assumed state of the value never being used for anything
// and then works towards a worst state of it being used for everything.
//
// Each update does three things:
// - walks the value's defining ops and transitive uses;
// - removes the assumed NOT_* bits implied by those ops;
// - joins (`^=`) in the state of related values (tied operands/results,
//   select inputs/results, branch arguments, call results), so usage flows
//   both forward and backward.
//
// Elements returned by solver.getElementFor() are bound by reference
// (`auto &`). They are allocated and owned by the solver (see
// createForPosition), and copying one per query is pure overhead.
class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
 public:
  using BaseType = AbstractResourceUsage<DFX::ValueElement>;

  // Allocates a new element for |pos| using the solver's allocator.
  static ValueResourceUsage &createForPosition(const Position &pos,
                                               DFX::Solver &solver) {
    return *(new (solver.getAllocator()) ValueResourceUsage(pos));
  }

  const std::string getName() const override { return "ValueResourceUsage"; }
  const void *getID() const override { return &ID; }

  // Identifies ValueResourceUsage elements by the address of |ID|.
  static bool classof(const DFX::AbstractElement *element) {
    return (element->getID() == &ID);
  }

  static const char ID;

 private:
  explicit ValueResourceUsage(const Position &pos) : BaseType(pos) {}

  // Starts analysis of the |value| with known bits based on its resource type.
  void initializeValue(Value value, DFX::Solver &solver) override {
    auto resourceType = value.getType().cast<IREE::Stream::ResourceType>();
    initializeFromType(resourceType);
  }

  // Updates the usage based on the op defining the value.
  // This may be dynamic as the result value may be tied to an operand that
  // itself is under analysis.
  //
  // |result| is one of the defining results the explorer found for |value|.
  // There may be several, e.g. for block arguments or call results.
  void updateFromDefiningOp(Value value, OpResult result, DFX::Solver &solver) {
    // Some tied uses route through ops that change types - ignore those.
    if (!result.getType().isa<IREE::Stream::ResourceType>()) return;
    TypeSwitch<Operation *, void>(result.getOwner())
        .Case([&](mlir::arith::SelectOp op) {
          // Either input may flow to the result so join with both.
          auto &trueUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getTrueValue()),
              DFX::Resolution::REQUIRED);
          auto &falseUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getFalseValue()),
              DFX::Resolution::REQUIRED);
          getState() ^= trueUsage.getState();
          getState() ^= falseUsage.getState();
        })
        .Case([&](IREE::Util::DoNotOptimizeOp op) {
          auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getOperand(0)),
              DFX::Resolution::REQUIRED);
          getState() ^= sourceUsage.getState();
        })
        .Case([&](IREE::Util::GlobalLoadOp op) {
          removeAssumedBits(NOT_GLOBAL_READ);
          // Values loaded from constant globals are constant.
          // queryGlobalInfoFrom returns a pointer; guard against null
          // (presumably an unresolvable global - confirm) instead of
          // dereferencing it unchecked.
          auto *globalInfo =
              solver.getExplorer().queryGlobalInfoFrom(op.getGlobal(), op);
          if (globalInfo) {
            auto globalType = globalInfo->op.getType()
                                  .template cast<IREE::Stream::ResourceType>();
            if (globalType.getLifetime() == IREE::Stream::Lifetime::Constant) {
              removeAssumedBits(NOT_CONSTANT);
            }
          }
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Util::GlobalLoadIndirectOp op) {
          removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_READ);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::ResourceStoreOp op) {
          removeAssumedBits(NOT_STAGING_WRITE);
          auto &targetUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getTarget()),
              DFX::Resolution::REQUIRED);
          getState() ^= targetUsage.getState();
        })
        .Case([&](IREE::Stream::TensorImportOp op) {
          removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncConstantOp op) {
          removeAssumedBits(NOT_CONSTANT | NOT_TRANSFER_WRITE);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncSplatOp op) {
          removeAssumedBits(NOT_TRANSFER_WRITE);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncCloneOp op) {
          removeAssumedBits(NOT_TRANSFER_WRITE);
          auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getSource()),
              DFX::Resolution::OPTIONAL);
          getState() ^= sourceUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncSliceOp op) {
          removeAssumedBits(NOT_TRANSFER_WRITE);
          auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getSource()),
              DFX::Resolution::OPTIONAL);
          getState() ^= sourceUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncFillOp op) {
          removeAssumedBits(NOT_TRANSFER_WRITE);
          auto &targetUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getTarget()),
              DFX::Resolution::REQUIRED);
          getState() ^= targetUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncUpdateOp op) {
          removeAssumedBits(NOT_TRANSFER_WRITE);
          auto &targetUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getTarget()),
              DFX::Resolution::REQUIRED);
          getState() ^= targetUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncCopyOp op) {
          removeAssumedBits(NOT_TRANSFER_WRITE);
          auto &targetUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getTarget()),
              DFX::Resolution::REQUIRED);
          getState() ^= targetUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncTransferOp op) {
          removeAssumedBits(NOT_TRANSFER_WRITE);
          auto &sourceUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getSource()),
              DFX::Resolution::OPTIONAL);
          bool isSourceStaging = !(sourceUsage.isAssumed(NOT_STAGING_READ) &&
                                   sourceUsage.isAssumed(NOT_STAGING_WRITE));
          bool isTargetStaging =
              !(isAssumed(NOT_STAGING_READ) && isAssumed(NOT_STAGING_WRITE));
          if (isSourceStaging != isTargetStaging) {
            // Can't transition staging across transfers.
            LLVM_DEBUG({
              llvm::dbgs() << "[ValueResourceUsage] skipping transfer source: ";
              op.print(llvm::dbgs(), solver.getAsmState());
              llvm::dbgs() << "\n";
            });
            return;
          }
          // TODO(benvanik): remove kFavorTransients.
          bool isSourceExternal = !sourceUsage.isAssumed(NOT_EXTERNAL);
          bool isTargetInternal = isAssumed(NOT_EXTERNAL);
          if (kFavorTransients && isSourceExternal && isTargetInternal) {
            LLVM_DEBUG({
              llvm::dbgs() << "[ValueResourceUsage] skipping forward prop of "
                              "external into internal:";
              op.print(llvm::dbgs(), solver.getAsmState());
              llvm::dbgs() << "\n";
            });
            return;
          }
          getState() ^= sourceUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncStoreOp op) {
          removeAssumedBits(NOT_STAGING_WRITE);
          auto &targetUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getTarget()),
              DFX::Resolution::REQUIRED);
          getState() ^= targetUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncDispatchOp op) {
          removeAssumedBits(NOT_DISPATCH_WRITE);
          // Results tied to an operand share that operand's usage; untied
          // results join with their own element.
          auto tiedOperand = op.getTiedResultOperand(result);
          if (tiedOperand) {
            auto &tiedUsage = solver.getElementFor<ValueResourceUsage>(
                *this, Position::forValue(tiedOperand),
                DFX::Resolution::REQUIRED);
            getState() ^= tiedUsage.getState();
          } else {
            auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
                *this, Position::forValue(result), DFX::Resolution::REQUIRED);
            getState() ^= resultUsage.getState();
          }
        })
        .Default([&](Operation *op) {});
  }

  // Updates the usage based on the particular usage as |operand|.
  // This walks through tied uses as well.
  void updateFromUse(Value value, OpOperand &operand, DFX::Solver &solver) {
    // Some tied uses route through ops that change types - ignore those.
    if (!operand.get().getType().isa<IREE::Stream::ResourceType>()) return;
    auto *userOp = operand.getOwner();
    unsigned operandIdx = operand.getOperandNumber();
    TypeSwitch<Operation *, void>(userOp)
        .Case([&](mlir::arith::SelectOp op) {
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](mlir::BranchOpInterface op) {
          auto &operandUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op->getOperand(operandIdx)),
              DFX::Resolution::REQUIRED);
          getState() ^= operandUsage.getState();
          // Join with each successor block argument the operand feeds.
          solver.getExplorer().walkOutgoingBranchOperandArguments(
              op, operandIdx, [&](Block *targetBlock, BlockArgument arg) {
                auto &argUsage = solver.getElementFor<ValueResourceUsage>(
                    *this, Position::forValue(arg), DFX::Resolution::OPTIONAL);
                getState() ^= argUsage.getState();
                return WalkResult::advance();
              });
        })
        .Case([&](mlir::func::ReturnOp op) {
          auto &operandUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getOperand(operandIdx)),
              DFX::Resolution::REQUIRED);
          getState() ^= operandUsage.getState();
          // Join with the matching result of every call to the parent callable.
          solver.getExplorer().walkIncomingCalls(
              op->getParentOfType<mlir::CallableOpInterface>(),
              [&](mlir::CallOpInterface callOp) {
                auto &callResultUsage =
                    solver.getElementFor<ValueResourceUsage>(
                        *this,
                        Position::forValue(callOp->getResult(operandIdx)),
                        DFX::Resolution::OPTIONAL);
                getState() ^= callResultUsage.getState();
                return WalkResult::advance();
              });
        })
        .Case([&](IREE::Util::DoNotOptimizeOp op) {
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult(0)),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Util::GlobalStoreOp op) {
          removeAssumedBits(NOT_GLOBAL_WRITE);
          // Values stored into constant globals are constant.
          // queryGlobalInfoFrom returns a pointer; guard against null
          // (presumably an unresolvable global - confirm).
          auto *globalInfo =
              solver.getExplorer().queryGlobalInfoFrom(op.getGlobal(), op);
          if (globalInfo) {
            auto globalType = globalInfo->op.getType()
                                  .template cast<IREE::Stream::ResourceType>();
            if (globalType.getLifetime() == IREE::Stream::Lifetime::Constant) {
              removeAssumedBits(NOT_CONSTANT);
            }
          }
        })
        .Case([&](IREE::Util::GlobalStoreIndirectOp op) {
          removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_WRITE);
        })
        .Case([&](IREE::Stream::TensorExportOp op) {
          removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
        })
        .Case([&](IREE::Stream::AsyncCloneOp op) {
          removeAssumedBits(NOT_TRANSFER_READ);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::OPTIONAL);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncSliceOp op) {
          removeAssumedBits(NOT_TRANSFER_READ);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::OPTIONAL);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncFillOp op) {
          removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_WRITE);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncUpdateOp op) {
          if (value == op.getUpdate()) {
            removeAssumedBits(NOT_TRANSFER_READ);
          } else {
            removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_WRITE);
            auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
                *this, Position::forValue(op.getResult()),
                DFX::Resolution::REQUIRED);
            getState() ^= resultUsage.getState();
          }
        })
        .Case([&](IREE::Stream::AsyncCopyOp op) {
          if (value == op.getSource()) {
            removeAssumedBits(NOT_TRANSFER_READ);
          } else {
            removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_WRITE);
            auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
                *this, Position::forValue(op.getResult()),
                DFX::Resolution::REQUIRED);
            getState() ^= resultUsage.getState();
          }
        })
        .Case([&](IREE::Stream::AsyncTransferOp op) {
          removeAssumedBits(NOT_TRANSFER_READ);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::OPTIONAL);
          bool isSourceStaging =
              !(isAssumed(NOT_STAGING_READ) && isAssumed(NOT_STAGING_WRITE));
          bool isTargetStaging = !(resultUsage.isAssumed(NOT_STAGING_READ) &&
                                   resultUsage.isAssumed(NOT_STAGING_WRITE));
          if (isSourceStaging != isTargetStaging) {
            // Can't transition staging across transfers.
            LLVM_DEBUG({
              llvm::dbgs() << "[ValueResourceUsage] skipping transfer target: ";
              op.print(llvm::dbgs(), solver.getAsmState());
              llvm::dbgs() << "\n";
            });
            return;
          }
          // TODO(benvanik): remove kFavorTransients.
          bool isSourceInternal = isAssumed(NOT_EXTERNAL);
          bool isTargetExternal = !resultUsage.isAssumed(NOT_EXTERNAL);
          if (kFavorTransients && isSourceInternal && isTargetExternal) {
            LLVM_DEBUG({
              llvm::dbgs() << "[ValueResourceUsage] skipping back prop of "
                              "external into internal due to kFavorTransients:";
              op.print(llvm::dbgs(), solver.getAsmState());
              llvm::dbgs() << "\n";
            });
            return;
          }
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncLoadOp op) {
          removeAssumedBits(NOT_STAGING_READ);
        })
        .Case([&](IREE::Stream::AsyncStoreOp op) {
          removeAssumedBits(NOT_MUTATED | NOT_STAGING_WRITE);
          auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
              *this, Position::forValue(op.getResult()),
              DFX::Resolution::REQUIRED);
          getState() ^= resultUsage.getState();
        })
        .Case([&](IREE::Stream::AsyncDispatchOp op) {
          removeAssumedBits(NOT_DISPATCH_READ);
          // Operands tied to results are written in place by the dispatch.
          for (auto result : op.getOperandTiedResults(operandIdx)) {
            removeAssumedBits(NOT_MUTATED | NOT_DISPATCH_WRITE);
            auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
                *this, Position::forValue(result), DFX::Resolution::REQUIRED);
            getState() ^= resultUsage.getState();
          }
        })
        .Default([&](Operation *op) {});
  }

  // Updates the usage state of |value| by walking all defining ops (up through
  // function arguments, branch arguments, and tied results) and all transitive
  // uses (down through function calls, branches, and tied operands).
  ChangeStatus updateValue(Value value, DFX::Solver &solver) override {
    auto assumedBits = getAssumed();
    auto traversalResult = TraversalResult::COMPLETE;

    // Join with defining ops - of which there may be multiple if we come from
    // a branch/region argument or call result.
    traversalResult |=
        solver.getExplorer().walkDefiningOps(value, [&](OpResult result) {
          updateFromDefiningOp(value, result, solver);
          return WalkResult::advance();
        });

    // Join with using ops.
    traversalResult |=
        solver.getExplorer().walkTransitiveUses(value, [&](OpOperand &operand) {
          updateFromUse(value, operand, solver);
          return WalkResult::advance();
        });

    // If not every def/use could be visited, treat the value as external.
    if (traversalResult == TraversalResult::INCOMPLETE) {
      removeAssumedBits(NOT_EXTERNAL);
    }

    // Filter out impossible states by marking the state invalid.
    // The fixpoint framework will try again.
    // NOTE(review): this validates |assumedBits|, which was captured *before*
    // this update's joins. An invalid state produced by this update is only
    // caught on a later update of this element. Checking getAssumed() may have
    // been intended; confirm before changing, because it alters analysis
    // results.
    if (!isValidState(assumedBits)) {
      return indicatePessimisticFixpoint();
    }

    return assumedBits == getAssumed() ? ChangeStatus::UNCHANGED
                                       : ChangeStatus::CHANGED;
  }

  friend class DFX::Solver;
};

// Only the address of ID is meaningful (see getID/classof).
const char ValueResourceUsage::ID = 0;
// Configures the explorer that backs the solver. Traversal is shallow by
// default and recurses only into initializers, functions and stream dialect
// ops; executable contents are ignored.
ResourceUsageAnalysis::ResourceUsageAnalysis(Operation *rootOp)
    : explorer(rootOp, TraversalAction::SHALLOW),
      solver(explorer, allocator) {
  // Op bodies whose resource values we want to analyze.
  explorer.setOpAction<IREE::Util::InitializerOp>(TraversalAction::RECURSE);
  explorer.setOpAction<mlir::func::FuncOp>(TraversalAction::RECURSE);
  explorer.setDialectAction<IREE::Stream::StreamDialect>(
      TraversalAction::RECURSE);

  // Executable contents (linalg goo, etc) are out of scope for this analysis.
  explorer.setOpAction<IREE::Stream::ExecutableOp>(TraversalAction::IGNORE);

  explorer.initialize();
}

ResourceUsageAnalysis::~ResourceUsageAnalysis() = default;
// Returns the assumed usage computed for |value|, or llvm::None when the
// solver holds no ValueResourceUsage element for it (for example, a value
// that was never visited by run()).
llvm::Optional<ResourceUsageBitfield>
ResourceUsageAnalysis::tryLookupResourceUsage(Value value) {
  if (auto *element = solver.lookupElementFor<ValueResourceUsage>(
          Position::forValue(value))) {
    return element->getAssumedUsage();
  }
  return llvm::None;
}
// Seeds a ValueResourceUsage element for every resource-typed SSA value the
// explorer can reach and then runs the solver to a fixpoint.
LogicalResult ResourceUsageAnalysis::run() {
  // TODO(benvanik): initialize globals and track usage through them.
  // Today we pin globals to <constant> or <variable> but it'd be nice to
  // set that based on actual usage here.
  //
  // Initialize globals that we need to resolve.
  // explorer.forEachGlobal([&](const auto *globalInfo) {
  //   auto globalType = globalInfo->op.type();
  //   if (globalType.template isa<IREE::Stream::ResourceType>()) {
  //     solver.getOrCreateElementFor<GlobalResourceUsage>(
  //         Position::forOperation(globalInfo->op));
  //   }
  // });

  // Only stream resources are tracked; every other value is skipped.
  explorer.walkValues([&](Value value) {
    if (!value.getType().isa<IREE::Stream::ResourceType>()) {
      return WalkResult::advance();
    }
    solver.getOrCreateElementFor<ValueResourceUsage>(
        Position::forValue(value));
    return WalkResult::advance();
  });

  return solver.run();
}
} // namespace Stream
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir