// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include "compiler/IR/Ops.h"
#include "compiler/Utils/DispatchUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"

namespace mlir {
namespace iree_compiler {
namespace {

// Returns true if the given |op| can be dispatched in all cases.
// Other passes may handle special cases of these ops but this initial
// identification is conservative.
bool isDispatchableOp(Operation *op) {
if (op->getDialect() && op->getDialect()->getNamespace().startswith("iree")) {
// Ignore things we've already produced as they should only relate to
// sequencer operations.
return false;
} else if (op->isKnownTerminator()) {
// Currently we skip all terminators as we want to leave them in the block
// to keep it valid. Future folding passes may take care of them if they are
// worth bringing into the dispatch region.
return false;
} else if (isa<CallOp>(op)) {
// This may be handled by a control-flow folding pass later once we have
// done our initial analysis and know what functions are compatible.
return false;
} else if (isa<CallIndirectOp>(op)) {
// Indirect calls are not supported in dispatch code.
return false;
} else if (isa<AllocOp>(op)) {
// Allocations are sequencer ops.
// Note that we could support static allocations (convert to stack/etc).
return false;
} else if (isa<ConstantOp>(op)) {
// Constants are handled in the RematerializeDispatchConstants pass.
// We do that independently so that we can more easily see the use of
// constants across all dispatches instead of just on an individual basis
// as we do here.
return false;
} else if (isa<xla_hlo::DynamicUpdateSliceOp>(op)) {
// TODO(benvanik): lower these to the sequencer dialect prior to ID'ing.
return false;
}
return true;
}

// Returns true if the given |op| can have other ops fused into it.
// This is sketchy and it'd be nice to define this as an op property instead.
//
// What we are looking for in foldable ops is whether the execution of the op
// when fused has some possible benefit (or at least, a non-negative cost).
// Eventually we want to allow backends to vote on this and allow multiple
// folding strategies within the same executable. For now we just hardcode what
// we know for the ops we have.
//
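// For example: element-wise xla_hlo ops make fine fusion roots since their
// producers can execute in the same region at no extra cost, while dot and
// conv are excluded below because they map to standalone hand-written
// kernels today.
//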
// Preconditions: isDispatchableOp(op) == true.
bool isFusionRootOp(Operation *op) {
if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {
    // We have hand-written kernels for these ops right now that we want to
    // keep standalone. When we do a bit more magic we should allow these ops
    // to fold.
return false;
}
return true;
}

// Returns true if the given |op| can be fused into other ops.
//
// Ops that perform narrowing on shapes (such as reduction ops) should not
// generally be fused with other downstream ops (probably...). This avoids
// potential oversampling and indexing issues and allows backends to perform
// more efficient rooted cascading reduction dispatches.
//
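// For example: producers feeding an xla_hlo.reduce input may be fused in
// ahead of it, but the reduction result is left as a cut point so that
// consumers root their own regions.
//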
// Preconditions: isDispatchableOp(op) == true.
bool isFusableOp(Operation *op) {
if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {
return false;
} else if (isa<xla_hlo::ReduceOp>(op)) {
// Reduction is usually a dedicated root operation - we can shove things in
// the front of it but not behind.
return false;
}
return true;
}

// Returns the ops in |unsortedOps| sorted into an arbitrary valid topological
// order (producers before consumers).
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
//
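// For example, if op B consumes a result of op A and op C consumes a result
// of op B, the returned order is [A, B, C] regardless of the order in which
// the ops appear in |unsortedOps|.
//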
// Preconditions: |unsortedOps| has no cycles within the set of ops.
std::vector<Operation *> sortOpsTopologically(
const llvm::SetVector<Operation *> &unsortedOps) {
llvm::SetVector<Operation *> unmarkedOps;
unmarkedOps.insert(unsortedOps.begin(), unsortedOps.end());
llvm::SetVector<Operation *> markedOps;
  using VisitFn = std::function<void(Operation *op)>;
VisitFn visit = [&](Operation *op) {
if (markedOps.count(op) > 0) return;
for (auto *result : op->getResults()) {
for (auto *user : result->getUsers()) {
// Don't visit ops not in our set.
if (unsortedOps.count(user) == 0) continue;
visit(user);
}
}
markedOps.insert(op);
};
while (!unmarkedOps.empty()) {
auto *op = unmarkedOps.pop_back_val();
visit(op);
}
auto sortedOps = markedOps.takeVector();
std::reverse(sortedOps.begin(), sortedOps.end());
return sortedOps;
}

// Recursively traverses the IR DAG along the operand edges to find ops we are
// able to fuse and appends them to |subgraph|.
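// The walk stops at ops whose results escape the pending subgraph, at
// non-dispatchable ops, and at ops that are not fusable (see isFusableOp).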
void gatherFusionOps(Operation *op, llvm::SetVector<Operation *> *subgraph) {
// Skip ops that are used outside of the subgraph we are building.
for (auto *result : op->getResults()) {
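    // Results with zero or one use are assumed not to escape the subgraph
    // and are not checked below.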
if (result->use_empty() || result->hasOneUse()) continue;
for (auto *user : result->getUsers()) {
if (subgraph->count(user) == 0) {
// Op that consumes the result is not (yet) in the subgraph.
// For now we'll ignore these as it may represent a fork that we don't
// want to join too early.
return;
}
}
}
// Walk backward up to ops providing our input operands.
for (auto *operand : op->getOperands()) {
auto *sourceOp = operand->getDefiningOp();
if (!sourceOp) continue;
if (subgraph->count(sourceOp) == 0) {
if (isDispatchableOp(sourceOp) && isFusableOp(sourceOp)) {
gatherFusionOps(sourceOp, subgraph);
}
}
}
subgraph->insert(op);
}

// Finds all ops that can be fused together with the given |rootOp| by searching
// backwards in the op order through input edges.
// Returns a topologically sorted list of all fused ops with |rootOp| at the
// end.
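// For example (illustrative IR only): rooting at the final add in
//   %0 = "xla_hlo.add"(%a, %b) ...
//   %1 = "xla_hlo.add"(%0, %c) ...
//   %2 = "xla_hlo.add"(%1, %d) ...
// returns all three ops in order, while rooting at an xla_hlo.dot returns
// just the dot since it is not a fusion root.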
std::vector<Operation *> findFusionSubgraphFromRoot(Operation *rootOp) {
if (!isFusionRootOp(rootOp)) {
return {rootOp};
}
llvm::SetVector<Operation *> subgraph;
subgraph.insert(rootOp);
gatherFusionOps(rootOp, &subgraph);
return sortOpsTopologically(subgraph);
}

// Identifies ranges of dispatchable ops and moves them into dispatch regions.
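// Runs to a fixed point: each formed region mutates the block, so the scan
// restarts until a full pass over the block finds no new regions.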
LogicalResult identifyBlockDispatchRegions(FuncOp func, Block *block) {
// Fixed point iteration until we can no longer fuse anything.
bool didFindAnyNewRegions;
do {
    didFindAnyNewRegions = false;
    // Iterate in reverse so we root at ops further along in the op list.
    for (auto &rootOp : llvm::reverse(*block)) {
if (!isDispatchableOp(&rootOp)) {
// Op should remain at the sequencer level.
continue;
}
// Attempt to find all operations, including rootOp, that can be fused.
// The ops will be sorted in topological order with rootOp as the last op.
// Worst case we may end up with a subgraph of only the rootOp.
auto fusedSubgraph = findFusionSubgraphFromRoot(&rootOp);
      // Compute the workload based on the output shape.
      // When the op is variadic, all output shapes match, so we can just take
      // the first.
auto *workload = calculateWorkload(&rootOp, rootOp.getResult(0));
// Try to build a dispatch region from this root.
if (failed(buildDispatchRegion(func, block, workload, fusedSubgraph))) {
return failure();
}
      // We successfully created a dispatch region, so start the scan over:
      // we've likely invalidated the block structure (and our iterator).
didFindAnyNewRegions = true;
break;
}
} while (didFindAnyNewRegions);
return success();
}

} // namespace

// Identifies dispatchable ops and moves them into iree.dispatch_regions.
// Some ops, such as call, are deferred to later passes to handle.
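// For example (schematic, not exact syntax): a block computing
//   %0 = "xla_hlo.add"(%a, %b) ...
//   %1 = "xla_hlo.add"(%0, %c) ...
// is rewritten so both adds execute inside one iree.dispatch_region whose
// workload is derived from the shape of %1.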
class IdentifyDispatchRegionsPass
: public FunctionPass<IdentifyDispatchRegionsPass> {
public:
void runOnFunction() override {
auto func = getFunction();
for (auto &block : func) {
if (failed(identifyBlockDispatchRegions(func, &block))) {
return signalPassFailure();
}
}
}
};

std::unique_ptr<OpPassBase<FuncOp>> createIdentifyDispatchRegionsPass() {
return std::make_unique<IdentifyDispatchRegionsPass>();
}

static PassRegistration<IdentifyDispatchRegionsPass> pass(
"iree-identify-dispatch-regions",
"Conservatively identifies dispatch regions in functions.");
} // namespace iree_compiler
} // namespace mlir