blob: ae0fb52ad962f0ff8d533f28de7b752ddba6d739 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compiler/Translation/Interpreter/InterpreterExecutableTranslation.h"
#include <cstdint>
#include <iostream>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/minireflect.h"
#include "compiler/IR/ConfigOps.h"
#include "compiler/IR/Interpreter/OpWriters.h"
#include "compiler/IR/Types.h"
#include "compiler/Serialization/VMFunctionBuilder.h"
#include "compiler/Serialization/VMFunctionTableBuilder.h"
#include "compiler/Serialization/VMModuleBuilder.h"
#include "compiler/Transforms/Interpreter/Passes.h"
#include "compiler/Transforms/Passes.h"
#include "compiler/Utils/Macros.h"
#include "compiler/Utils/OpUtils.h"
#include "compiler/Utils/TranslationUtils.h"
#include "schemas/executable_def_generated.h"
#include "schemas/module_def_generated.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Translation.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
namespace mlir {
namespace iree_compiler {
namespace {
// Builds a pass pipeline that optimizes and legalizes the module to the form
// expected by translation.
void buildLegalizeInputPassPipeline(PassManager *passManager) {
// Standard passes that shake out a lot of garbage.
// Some may have been run prior to translation but this ensures we are always
// in a known state.
passManager->addPass(createCanonicalizerPass());
passManager->addPass(createLoopFusionPass());
passManager->addPass(createLoopInvariantCodeMotionPass());
passManager->addPass(createMemRefDataFlowOptPass());
passManager->addPass(createCanonicalizerPass());
passManager->addPass(createSimplifyAffineStructuresPass());
passManager->addPass(createCSEPass());
passManager->addPass(createCanonicalizerPass());
// Eliminate ops we don't care about based on a lack of side-effects.
// IREE does not guarantee exception/error behavior of dead ops.
passManager->addPass(createAggressiveOpEliminationPass());
// Expand uses of tuples into independent args/results.
passManager->addPass(createConvertFromTupleCallingConventionPass());
passManager->addPass(createCanonicalizerPass());
}
// Builds a pass pipeline that converts functions to the iree_hl_interp dialect.
void buildInterpreterConversionPassPipeline(PassManager *passManager) {
// We don't need the IREE binding ops anymore, as we match the calling
// convention exactly (we're the same VM).
passManager->addPass(createMakeExecutableABIPass());
// Convert to the memref calling convention and optimize away as many
// loads and stores as we can prior to progressing.
passManager->addPass(createConvertToMemRefCallingConventionPass());
passManager->addPass(createCanonicalizerPass());
passManager->addPass(createMemRefDataFlowOptPass());
// Convert various dialects to IREE opcodes and cleanup leftover conversions.
passManager->addPass(createLowerToInterpreterDialectPass());
passManager->addPass(createCanonicalizerPass());
passManager->addPass(createAggressiveOpEliminationPass());
// Widen reduction functions (that have iree.executable.reduction attrs) to
// use their primitive IREE ops.
passManager->addPass(createExpandReductionsToOpsPass());
// Convert any uses of index to int32_t (as we explicitly don't want to
// support dynamic index width).
// This also looks for other weird types (i1, etc).
passManager->addPass(createLegalizeTypeStoragePass());
// Perform any last-minute optimizations to trim down the IR.
passManager->addPass(createAggressiveOpEliminationPass());
passManager->addPass(createCanonicalizerPass());
passManager->addPass(createLoopFusionPass());
passManager->addPass(createLoopInvariantCodeMotionPass());
passManager->addPass(createMemRefDataFlowOptPass());
passManager->addPass(createCanonicalizerPass());
passManager->addPass(createCSEPass());
passManager->addPass(createCanonicalizerPass());
// Drop all functions that are not reachable.
passManager->addPass(createDropUnreachableExecutableFunctionsPass());
}
// Builds a pass pipeline that lowers the iree_hl_interp dialect to the
// iree_ll_interp dialect and prepares for serialization.
void buildInterpreterLoweringPassPipeline(PassManager *passManager) {
// Lower iree_hl_interp -> iree_ll_interp.
passManager->addPass(createLowerInterpreterDialectPass());
// Assign ordinals used by the bytecode to reference executables and
// functions.
passManager->addPass(createAssignFunctionOrdinalsPass());
}
class InterpreterTranslator {
public:
explicit InterpreterTranslator(ExecutableTranslationOptions options)
: options_(options) {}
const ExecutableTranslationOptions &options() const { return options_; }
std::unique_ptr<iree::ExecutableDefT> translateExecutable(
IREE::ExecutableOp executableOp);
private:
LogicalResult translateExecutableModule(IREE::ExecutableOp executableOp,
ModuleOp moduleOp,
VMModuleBuilder *moduleBuilder);
LogicalResult declareFunction(FuncOp function,
VMModuleBuilder *moduleBuilder);
LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder);
ExecutableTranslationOptions options_;
};
std::unique_ptr<iree::ExecutableDefT>
InterpreterTranslator::translateExecutable(IREE::ExecutableOp executableOp) {
auto moduleOp = executableOp.getInnerModule();
// Run all passes to go from input to the iree_ll_interp dialect.
auto executableConversionPasses =
createPassManager(moduleOp.getContext(), options());
buildLegalizeInputPassPipeline(executableConversionPasses.get());
buildInterpreterConversionPassPipeline(executableConversionPasses.get());
buildInterpreterLoweringPassPipeline(executableConversionPasses.get());
if (failed(runPassPipeline(options(), executableConversionPasses.get(),
moduleOp))) {
executableOp.emitError() << "Failed to run conversion passes";
return {};
}
// Build the module bytecode.
::flatbuffers::FlatBufferBuilder fbb;
VMModuleBuilder moduleBuilder(&fbb);
if (failed(
translateExecutableModule(executableOp, moduleOp, &moduleBuilder))) {
executableOp.emitError() << "Failed to translate executable module";
return {};
}
auto moduleDef = moduleBuilder.Finish();
if (moduleDef.IsNull()) {
moduleOp.emitError() << "Failed to verify completed module def";
return {};
}
auto bytes = moduleBuilder.Serialize(moduleDef);
if (bytes.empty()) {
moduleOp.emitError() << "Failed to serialize final module def";
return {};
}
OpBuilder builder(executableOp);
executableOp.setAttr("format", builder.getI32IntegerAttr(static_cast<int32_t>(
IREE::ExecutableFormat::IreeBytecode)));
auto executableDef = std::make_unique<iree::ExecutableDefT>();
executableDef->format =
static_cast<uint32_t>(IREE::ExecutableFormat::IreeBytecode);
executableDef->supported_features = iree::ExecutableFeature::kDebugging;
executableDef->contents = std::move(bytes);
return executableDef;
}
LogicalResult InterpreterTranslator::translateExecutableModule(
IREE::ExecutableOp executableOp, ModuleOp moduleOp,
VMModuleBuilder *moduleBuilder) {
// Declare functions first so that we get stable indices during declaration
// (as call ops need to use the function table).
for (auto function : moduleOp.getOps<FuncOp>()) {
RETURN_IF_FAILURE(declareFunction(function, moduleBuilder));
}
// Define functions now that all functions have been declared.
for (auto function : moduleOp.getOps<FuncOp>()) {
RETURN_IF_FAILURE(defineFunction(function, moduleBuilder));
}
return success();
}
LogicalResult InterpreterTranslator::declareFunction(
FuncOp function, VMModuleBuilder *moduleBuilder) {
auto *functionTable = moduleBuilder->function_table();
if (functionTable->IsFunctionDeclared(function)) {
// Already declared.
return success();
}
LinkageType linkageType;
if (function.isExternal()) {
linkageType = LinkageType::kImport;
} else if (function.getAttr("iree.executable.export")) {
linkageType = LinkageType::kExport;
} else {
linkageType = LinkageType::kInternal;
}
if (failed(functionTable->DeclareFunction(function, linkageType))) {
return function.emitError() << "Unable to declare function";
}
// Import functions must have their definition defined here so we get their
// type. Internal and export functions will be defined during conversion.
if (linkageType == LinkageType::kImport) {
VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(),
moduleBuilder->fbb());
auto functionOffset = functionBuilder.Finish();
if (functionOffset.IsNull()) {
return function.emitError()
<< "Failed to create import function bytecode";
}
RETURN_IF_FAILURE(
functionTable->DefineFunction(function, functionOffset, {}));
}
return success();
}
LogicalResult InterpreterTranslator::defineFunction(
FuncOp function, VMModuleBuilder *moduleBuilder) {
VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(),
moduleBuilder->fbb());
registerInterpreterCustomWriters(&functionBuilder);
RETURN_IF_FAILURE(functionBuilder.ConvertBytecode());
auto functionOffset = functionBuilder.Finish();
if (functionOffset.IsNull()) {
return function.emitError() << "Failed to serialize function";
}
RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction(
function, functionOffset, functionBuilder.source_map()));
return success();
}
} // namespace
llvm::Optional<ExecutableTranslationResult>
translateExecutableToInterpreterExecutable(
ArrayRef<IREE::ExecutableOp> executableOps,
ExecutableTranslationOptions options) {
InterpreterTranslator translator(options);
ExecutableTranslationResult translationResult;
for (auto executableOp : llvm::make_early_inc_range(executableOps)) {
auto executableDef = translator.translateExecutable(executableOp);
if (!executableDef) {
executableOp.emitError() << "Failed to translate one or more executables";
return llvm::None;
}
translationResult.executable_defs.push_back(std::move(executableDef));
}
return translationResult;
}
static ExecutableTranslationRegistration
InterpreterExecutableTranslationRegistration(
"interpreter-bytecode", translateExecutableToInterpreterExecutable);
} // namespace iree_compiler
} // namespace mlir