blob: 6da8bef2ec6d79ead1074e301fdbbf19c1dc0fbc [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOpUtils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/DispatchConfig.h"
#include "iree/compiler/Dialect/Flow/Utils/WorkloadUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#define DEBUG_TYPE "iree-dispatch"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
namespace {
struct DispatchableOp {
OpDispatchPolicy::AnchorBenefit anchorBenefit;
size_t index;
Operation *op;
bool operator<(const DispatchableOp &other) const {
// Note inverted index: this is so that traversing a sorted list in
// reverse yields a topological ordering for each anchorBenefit.
return std::tie(anchorBenefit, other.index) <
std::tie(other.anchorBenefit, index);
}
};
struct DispatchRegion {
DispatchRegionOp op;
Operation *anchorOp;
Block &getEntryBlock() { return op.body().front(); }
static llvm::Optional<DispatchRegion> form(Operation *anchorOp) {
auto loc = anchorOp->getLoc();
if (anchorOp->getNumResults() < 1) {
emitError(loc) << "dispatch anchor op must have at least one result: "
<< *anchorOp;
return llvm::None;
}
Value result = anchorOp->getResult(0);
Value workload = calculateWorkload(anchorOp, result);
if (!workload) return llvm::None;
OpBuilder builder(anchorOp->getContext());
auto created =
DispatchRegionOp::formFromAnchorOp(workload, anchorOp, builder);
if (!created) return llvm::None;
return DispatchRegion{created->first, created->second};
}
// After a call to inlineDispatchOp, adds the results of the inlined op to
// the dispatch region's results and redirects any uses outside of the
// dispatch region.
void returnAndReplaceUses(Operation *origOp, Operation *inlinedOp) {
// Extend the arity of the dispatch region.
OpBuilder builder(op.getContext());
llvm::SmallVector<Value, 4> addlResults(inlinedOp->getResults());
origOp->replaceAllUsesWith(
DispatchRegionOp::appendResults(op, addlResults, builder));
}
};
// Clones and hoists any identity metadata ops from the operands and results
// of the dispatch region back out into the surrounding block.
// This function is not general purpose: it only knows how to undo sinking
// done by dispatch region formation.
void hoistDispatchRegionMetadataOps(DispatchRegion &dr,
OpDispatchPolicy &policy) {
BlockAndValueMapping mapping;
Block &block = dr.getEntryBlock();
for (unsigned i = 0, e = block.getNumArguments(); i < e; ++i) {
mapping.map(block.getArgument(i), dr.op.args()[i]);
}
// Hoist metadata ops from the operand edge.
for (auto it : llvm::enumerate(block.getArguments())) {
auto &blockArg = it.value();
for (auto &blockUse : blockArg.getUses()) {
Operation *useOp = blockUse.getOwner();
if (!policy.isIdentityMetadata(useOp) || useOp->getOperand(0) != blockArg)
continue;
OpBuilder builder(dr.op);
Operation *newOp = builder.clone(*useOp, mapping);
dr.op.argsMutable().slice(it.index(), 1).assign(newOp->getResult(0));
}
}
// Hoist metadata ops from the result edge.
// Since initial formation can only have a single block, this is safe.
auto *terminator = block.getTerminator();
for (auto it : llvm::enumerate(terminator->getOperands())) {
Operation *defOp = it.value().getDefiningOp();
if (!defOp || !policy.isIdentityMetadata(defOp)) continue;
OpBuilder builder(dr.op.getContext());
builder.setInsertionPointAfter(dr.op);
Operation *newOp = builder.clone(*defOp, mapping);
dr.op.getResult(it.index()).replaceAllUsesWith(newOp->getResult(0));
newOp->setOperand(0, dr.op.getResult(it.index()));
}
}
void findDispatchableAnchorOps(Block &block, OpDispatchPolicy &policy,
OpDispatchPolicy::AnchorBenefit maxAnchorBenefit,
llvm::SmallVectorImpl<DispatchableOp> &ops) {
for (auto it : llvm::enumerate(block.getOperations())) {
Operation *op = &it.value();
// Skip any already formed dispatch regions and non dispatchable ops.
if (isa<DispatchRegionOp>(op)) continue;
if (!policy.isDispatchable(op)) continue;
OpDispatchPolicy::AnchorBenefit anchorBenefit = policy.getAnchorBenefit(op);
if (anchorBenefit > maxAnchorBenefit || anchorBenefit <= 0) continue;
ops.push_back({anchorBenefit, it.index(), op});
}
}
// Maintains a worklist of operations that are potential fusion candidates.
// By default, items are popped in inverse topological order. An operation
// can only be added to a worklist once and later additions will be ignored.
class FusionWorklist {
public:
FusionWorklist(Block *block, bool inverseTopological = true)
: block(block), inverseTopological(inverseTopological) {}
// Adds defining ops of operands to the worklist.
void addOperandDefs(OperandRange operands) {
for (Value operand : operands) {
Operation *def = operand.getDefiningOp();
if (!def) continue;
if (def->getBlock() != block) continue;
if (!isValidItem(def)) continue;
if (!visited.insert(def).second) continue;
worklist.push_back(def);
dirty = true;
}
}
// Adds uses.
void addResultUses(ResultRange results) {
for (auto result : results) {
for (auto &use : result.getUses()) {
Operation *def = use.getOwner();
if (def->isKnownTerminator()) continue;
if (def->getBlock() != block) continue;
if (!isValidItem(def)) continue;
if (!visited.insert(def).second) continue;
worklist.push_back(def);
dirty = true;
}
}
}
// Pops the next operation or nullptr if empty.
Operation *popNext() {
if (worklist.empty()) return nullptr;
if (dirty) sort();
return worklist.pop_back_val();
}
private:
bool isValidItem(Operation *op) {
// Dispatch regions cannot be added to the worklist because they are
// modified/deleted in place and can not be guaranteed valid for the
// duration of the worklist.
return !llvm::isa<DispatchRegionOp>(op);
}
// Sorts worklist items such that popNext() values pop in inverse
// topological order.
void sort() {
if (inverseTopological) {
llvm::sort(worklist, [](Operation *left, Operation *right) {
return left->isBeforeInBlock(right);
});
} else {
llvm::sort(worklist, [](Operation *left, Operation *right) {
return right->isBeforeInBlock(left);
});
}
}
Block *block;
llvm::SmallVector<Operation *, 4> worklist;
llvm::SmallDenseSet<Operation *, 4> visited;
bool inverseTopological;
bool dirty = false;
};
LogicalResult fuseInputs(DispatchRegion &dispatchRegion,
OpDispatchPolicy &policy) {
LLVM_DEBUG(llvm::dbgs() << "++ FUSING INPUTS\n");
FusionWorklist worklist(dispatchRegion.op.getOperation()->getBlock());
worklist.addOperandDefs(dispatchRegion.op.getOperands());
while (Operation *nextOp = worklist.popNext()) {
if (!policy.isDispatchable(nextOp)) continue;
auto action = policy.fuseInput(dispatchRegion.anchorOp, nextOp);
LLVM_DEBUG(llvm::dbgs().indent(2));
if (action == OpDispatchPolicy::FusionType::MOVE_INTO) {
return nextOp->emitError() << "cannot fuse input with MOVE_INTO action";
} else if (action == OpDispatchPolicy::FusionType::DISABLED) {
LLVM_DEBUG(llvm::dbgs()
<< "- SKIP NON FUSABLE INPUT: " << *nextOp << "\n");
continue;
}
// Always inline inputs at the top of the block. Since we are processing
// the worklist in inverse topological order, this preserves the original
// ordering.
LLVM_DEBUG(llvm::dbgs() << "- FUSABLE INPUT(" << static_cast<int>(action)
<< "): " << *nextOp << "\n");
OpBuilder builder(nextOp->getContext());
auto *inlinedOp = dispatchRegion.op.inlineOp(nextOp, builder);
if (!inlinedOp) {
return failure();
}
worklist.addOperandDefs(nextOp->getOperands());
// Erase the op if it has no uses. This keeps it from forming regions
// that will be dce'd later (or getting in the way of the benefit
// scheme). Note that dispatchable ops have no side effects, which
// makes this simple check safe.
// The dispatch region must be optimized to remove unused arguments
// resulting from this fusion.
DispatchRegionOp::dceOperandsAndResults(dispatchRegion.op);
if (nextOp->use_empty()) {
nextOp->erase();
}
}
return success();
}
LogicalResult fuseOutputs(DispatchRegion &dispatchRegion,
OpDispatchPolicy &policy) {
LLVM_DEBUG(llvm::dbgs() << "++ FUSING OUTPUT\n");
FusionWorklist worklist(dispatchRegion.op.getOperation()->getBlock(),
/*inverseTopological=*/false);
worklist.addResultUses(dispatchRegion.op.getResults());
while (Operation *nextOp = worklist.popNext()) {
if (!policy.isDispatchable(nextOp)) continue;
auto action = policy.fuseOutput(dispatchRegion.anchorOp, nextOp);
LLVM_DEBUG(llvm::dbgs().indent(2));
if (action == OpDispatchPolicy::FusionType::DISABLED) {
LLVM_DEBUG(llvm::dbgs()
<< "- SKIP NON FUSABLE INPUT: " << *nextOp << "\n");
continue;
}
if (action != OpDispatchPolicy::FusionType::MOVE_INTO) {
return nextOp->emitError()
<< "cannot fuse output except with MOVE_INTO action";
}
LLVM_DEBUG(llvm::dbgs() << "- FUSABLE OUTPUT(" << static_cast<int>(action)
<< "): " << *nextOp << "\n");
// Since results will be redirected to the region results, need to scan
// for worklist items before changing use-def chain.
worklist.addResultUses(nextOp->getResults());
OpBuilder builder(nextOp->getContext());
auto *inlinedOp =
dispatchRegion.op.inlineOp(nextOp, builder, /*positionAtEnd=*/true);
if (!inlinedOp) {
return failure();
}
dispatchRegion.returnAndReplaceUses(nextOp, inlinedOp);
if (nextOp->use_empty()) {
nextOp->erase();
}
}
return success();
}
LogicalResult processBlock(Block &block, OpDispatchPolicy &policy) {
int maxAnchorBenefit =
std::numeric_limits<OpDispatchPolicy::AnchorBenefit>::max();
// Maps DispatchRegionOp to the anchor op.
llvm::DenseMap<Operation *, Operation *> dispatchRegions;
// Per iteration scratch.
llvm::SmallVector<DispatchableOp, 10> dispatchableOps;
// Loop backwards from high anchor benefit to low.
for (;;) {
dispatchableOps.clear();
// Enumerate un-dispatched ops.
findDispatchableAnchorOps(block, policy, maxAnchorBenefit, dispatchableOps);
if (dispatchableOps.empty()) break;
llvm::sort(dispatchableOps);
// Traversing from back->front will produce ops in [anchorPriority, index]
// order.
auto &d = dispatchableOps.back();
if (d.anchorBenefit <= 0) break;
LLVM_DEBUG(llvm::dbgs() << "FORM DISPATCH REGION(" << d.index << ":"
<< d.anchorBenefit << "): " << *d.op << "\n");
auto dispatchRegion = DispatchRegion::form(d.op);
if (!dispatchRegion) return failure();
dispatchRegions.insert(
std::make_pair(dispatchRegion->op, dispatchRegion->anchorOp));
// Fuse outputs prior to inputs, since they can yield more things to
// evaluate for input fusion.
if (failed(fuseOutputs(*dispatchRegion, policy))) return failure();
if (failed(fuseInputs(*dispatchRegion, policy))) return failure();
// Ensure all unused operands and results are dce'd.
DispatchRegionOp::dceOperandsAndResults(dispatchRegion->op);
hoistDispatchRegionMetadataOps(*dispatchRegion, policy);
}
return success();
}
// Identifies dispatchable ops and moves them into dispatch regions.
// Some ops, such as call, will be deferred until following passes.
class IdentifyDispatchRegions2Pass
: public PassWrapper<IdentifyDispatchRegions2Pass, FunctionPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Flow::FlowDialect>();
}
void runOnFunction() override {
// NOTE: we require the DispatchabilityAnalysisPass to have run first.
auto dispatchability = getCachedParentAnalysis<Dispatchability>();
if (!dispatchability.hasValue()) {
getFunction().emitError()
<< "dispatchability analysis not performed "
"on module; run -iree-flow-dispatchability-analysis first";
return signalPassFailure();
}
OpDispatchPolicy policy(*dispatchability);
for (auto &block : getFunction()) {
if (failed(processBlock(block, policy))) {
return signalPassFailure();
}
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createIdentifyDispatchRegions2Pass() {
return std::make_unique<IdentifyDispatchRegions2Pass>();
}
static PassRegistration<IdentifyDispatchRegions2Pass> pass(
"iree-flow-identify-dispatch-regions2",
"Conservatively identifies dispatch regions in functions (v2)");
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir