blob: a6640a5cad590f81414a42d91cb6b56e148867fe [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/Flow/Analysis/Dispatchability.h"
#include <list>
#include "iree/compiler/Dialect/Flow/Utils/DispatchUtils.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace iree_compiler {
// Runs a fresh dispatchability analysis over |moduleOp| and records the result
// of each analyzed function as a boolean "dispatchable" attribute on that
// function. Emits an error on the module and returns failure if the analysis
// itself fails.
// static
LogicalResult Dispatchability::annotateIR(ModuleOp moduleOp) {
  Dispatchability analysis;
  if (failed(analysis.recalculate(moduleOp))) {
    moduleOp.emitOpError()
        << "failed to analyze dispatchability for the module";
    return failure();
  }

  // Results are keyed by function name, so resolve each name back to its op
  // through the module's symbol table before attaching the attribute.
  Builder attrBuilder(moduleOp.getContext());
  SymbolTable symbols(moduleOp);
  for (auto &entry : analysis.funcDispatchability_) {
    auto funcOp = symbols.lookup<FuncOp>(entry.first);
    auto dispatchableAttr = attrBuilder.getBoolAttr(entry.second);
    funcOp.setAttr("dispatchable", dispatchableAttr);
  }
  return success();
}
// Recomputes the dispatchability of every function in |moduleOp|, discarding
// any previously stored results. Functions found to be dispatchable are also
// cloned into a private module owned by this analysis. Returns failure (with
// an error on the module) if some functions can never be resolved.
LogicalResult Dispatchability::recalculate(ModuleOp moduleOp) {
  // Start from a clean slate: no cached results and an empty clone module.
  funcDispatchability_.clear();
  funcCloneModuleOp_ = ModuleOp::create(UnknownLoc::get(moduleOp.getContext()));
  funcClones_.clear();

  OpBuilder cloneBuilder(funcCloneModuleOp_.get());

  // Iterate to a fixed point: a function whose result depends on a callee that
  // has not been classified yet is deferred to the next pass. We stop once a
  // full pass classifies nothing new.
  auto funcOps = moduleOp.getOps<FuncOp>();
  std::vector<FuncOp> pending(funcOps.begin(), funcOps.end());
  std::vector<FuncOp> current;
  bool madeProgress = true;
  while (madeProgress) {
    madeProgress = false;
    current.swap(pending);
    pending.clear();
    for (auto funcOp : current) {
      auto result = computeDispatchability(funcOp);
      if (!result.hasValue()) {
        // Undecided for now; retry on the next pass.
        pending.push_back(funcOp);
        continue;
      }
      madeProgress = true;

      const bool dispatchable = result.getValue();
      funcDispatchability_[funcOp.getName()] = dispatchable;
      if (!dispatchable) continue;

      // Keep a private copy of each dispatchable function.
      auto clonedFuncOp = cast<FuncOp>(cloneBuilder.clone(*funcOp));
      funcClones_[funcOp.getName()] = clonedFuncOp;
      funcCloneModuleOp_->push_back(clonedFuncOp);
    }
  }

  // Anything still pending after a pass without progress can only be waiting
  // on other pending functions.
  if (pending.empty()) return success();
  return moduleOp.emitError() << "cycle detected in dispatchability analysis";
}
// Decides whether |funcOp| is dispatchable by scanning every op in its body.
//
// Returns:
//   - false as soon as any op that blocks dispatch is found (or if the
//     function is external),
//   - llvm::None if the function calls another function whose dispatchability
//     has not been recorded in funcDispatchability_ yet (the caller is
//     expected to retry later),
//   - true if no op in the body blocks dispatch.
Optional<bool> Dispatchability::computeDispatchability(FuncOp funcOp) {
  if (funcOp.isExternal()) {
    // We assume all imports have side-effects right now, but that may not be
    // the case. We should add an attribute and check for it.
    return false;
  }

  // TODO(b/144530470): replace with tablegen attributes/interfaces.
  for (auto &block : funcOp.getBlocks()) {
    for (auto &op : block.getOperations()) {
      if (!IREE::Flow::isOpOfKnownDialect(&op)) {
        // Custom dialects aren't dispatchable (yet).
        return false;
      } else if (auto callOp = dyn_cast<CallOp>(op)) {
        if (callOp.getCallee() == funcOp.getName()) {
          // Recursion.
          continue;
        }
        auto it = funcDispatchability_.find(callOp.callee());
        if (it == funcDispatchability_.end()) {
          // Not yet calculated - yield.
          return llvm::None;
        }
        if (!it->second) {
          // Calling a non-dispatchable function makes this one
          // non-dispatchable too.
          return false;
        }
        // The callee is dispatchable, but that only clears this one call: the
        // remaining ops in the function still have to be checked, so keep
        // scanning instead of returning the callee's result.
        continue;
      } else if (isa<CallIndirectOp>(op)) {
        // Indirect calls are not supported and must first be devirtualized.
        return false;
      } else if (isa<mlir::ReturnOp>(op)) {
        // TODO(benvanik): widen to all known terminators? sometimes they may
        // have side-effects.
        continue;
      } else if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) {
        // Some unfusable ops must remain on their own.
        return false;
      } else if (isa<xla_hlo::ReduceOp>(op) ||
                 isa<xla_hlo::ReduceWindowOp>(op)) {
        // Reductions always become flow ops.
        return false;
      } else if (!MemoryEffectOpInterface::hasNoEffect(&op) ||
                 op.getNumRegions() != 0) {
        // Ops with side-effects cannot be dispatched as we must be able to
        // exactly model I/O.
        // TODO: Properly handle region side effects. For now any op that
        // carries a region is conservatively rejected.
        return false;
      }
    }
  }

  // All cases not handled above are (probably) dispatchable. This makes what we
  // do here a blocklist, though as we move towards more frontend dialects that
  // may not be the best idea.
  return true;
}
// Invokes |fn| once for each function recorded as dispatchable. The FuncOp
// handed to |fn| is the entry stored in funcClones_ (the copy made during
// recalculate), not the function in the originally analyzed module.
void Dispatchability::walkDispatchableOps(
    function_ref<void(FuncOp funcOp)> fn) {
  for (auto it = funcClones_.begin(), end = funcClones_.end(); it != end;
       ++it) {
    fn(it->second);
  }
}
// Returns the recorded dispatchability of the function named |funcName|.
// Names with no recorded result report false.
//
// The lookup uses find() rather than operator[] so that querying an unknown
// name does not insert a new entry into funcDispatchability_: the map is also
// iterated to annotate functions, and a stray entry would name a function
// that was never analyzed.
bool Dispatchability::isDispatchable(StringRef funcName) {
  auto it = funcDispatchability_.find(funcName);
  return it != funcDispatchability_.end() && it->second;
}
// Convenience overload: answers the same query as the name-based overload,
// using the symbol name of |funcOp| as the key.
bool Dispatchability::isDispatchable(FuncOp funcOp) {
  StringRef funcName = funcOp.getName();
  return isDispatchable(funcName);
}
// Analysis-invalidation hook. This analysis unconditionally reports that it
// has not been invalidated; the preserved-analyses set |pa| is not consulted.
bool Dispatchability::isInvalidated(
    const AnalysisManager::PreservedAnalyses &pa) {
  (void)pa;  // Intentionally unused.
  return false;
}
} // namespace iree_compiler
} // namespace mlir