blob: 61504aa6c534d96b54cbb1ef4be4e4714fee0471 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/PassDetail.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Flow {
// For each public func.func in the module, creates a public wrapper function
// named `<name>_benchmark` that takes no arguments and returns no results.
// The wrapper calls the original function with values loaded from private,
// zero-initialized IREE::Util::GlobalOps (one per argument). It passes every
// result to IREE::Util::DoNotOptimizeOp so the call is not removed as dead code.
// The original function is then made private and `noinline`, and its
// `iree.reflection` attribute is removed, so only the wrapper stays exported.
class ExportBenchmarkFuncsPass
    : public ExportBenchmarkFuncsBase<ExportBenchmarkFuncsPass> {
 public:
  void getDependentDialects(DialectRegistry& registry) const override {
    // The globals, loads and do-not-optimize sinks are Util dialect ops.
    registry.insert<IREE::Util::UtilDialect>();
  }

  void runOnOperation() override {
    auto moduleOp = getOperation();

    // Snapshot the public functions before wrapping anything. The wrappers
    // created below are themselves public and are inserted into the same
    // module, so they must not be picked up and wrapped again.
    SmallVector<mlir::func::FuncOp, 4> entryFuncOps;
    for (auto entryFuncOp : moduleOp.getOps<mlir::func::FuncOp>()) {
      if (entryFuncOp.isPublic()) entryFuncOps.push_back(entryFuncOp);
    }

    for (auto entryFuncOp : entryFuncOps) {
      if (failed(createEntryPointBenchmarkFunc(moduleOp, entryFuncOp))) {
        signalPassFailure();
        return;
      }
    }
  }

 private:
  // Creates a private, immutable IREE::Util::GlobalOp tagged `noinline`. It
  // holds a zero value of |inputType| and is named `_benchmark_input_<N>`, where
  // N comes from a per-pass-instance counter. The op is inserted at
  // |moduleBuilder|'s current insertion point.
  //
  // Emits an error at |loc| and returns a null op when no zero value can be
  // materialized for |inputType|. That happens when the type has a
  // non-static shape, or when Builder::getZeroAttr does not support it.
  //
  // NOTE(review): the generated name is not checked against symbols already
  // present in the module; a pre-existing `_benchmark_input_<N>` would clash.
  IREE::Util::GlobalOp createDummyInputVariableOp(Location loc, Type inputType,
                                                  OpBuilder& moduleBuilder) {
    // getZeroAttr builds a dense elements attribute for ranked tensors, which
    // requires a static shape. Reject dynamic shapes with a diagnostic here
    // rather than letting attribute construction assert.
    if (auto shapedType = inputType.dyn_cast<ShapedType>()) {
      if (!shapedType.hasStaticShape()) {
        mlir::emitError(loc)
            << "unsupported dynamically-shaped function argument type: "
            << inputType;
        return {};
      }
    }

    auto initialValue = moduleBuilder.getZeroAttr(inputType);
    if (!initialValue) {
      mlir::emitError(loc) << "unsupported function argument type: "
                           << inputType;
      return {};
    }

    std::string name = "_benchmark_input_" + std::to_string(uniqueId++);
    auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>(
        loc, name,
        /*isMutable=*/false, inputType, initialValue);
    globalOp.setPrivate();
    globalOp->setAttr("noinline", UnitAttr::get(moduleBuilder.getContext()));
    return globalOp;
  }

  // Creates `<entryFuncOp name>_benchmark`, a public `() -> ()` function with
  // a unit `iree.abi.stub` attribute and an `iree.reflection` dictionary of
  // {benchmark = "entry"}. The wrapper and the dummy input globals it reads are
  // inserted after |entryFuncOp|. The wrapper calls |entryFuncOp| with the
  // globals' values and passes each result to IREE::Util::DoNotOptimizeOp.
  // |entryFuncOp| is then marked `noinline`, loses `iree.reflection`, and is
  // made private.
  //
  // Returns failure, after an error has been emitted, if a dummy input could
  // not be created for one of the arguments. |moduleOp| is currently unused.
  LogicalResult createEntryPointBenchmarkFunc(mlir::ModuleOp moduleOp,
                                              mlir::func::FuncOp entryFuncOp) {
    OpBuilder moduleBuilder(&getContext());
    moduleBuilder.setInsertionPointAfter(entryFuncOp);
    Location loc = entryFuncOp.getLoc();

    // Create one dummy input global per function argument.
    SmallVector<IREE::Util::GlobalOp, 4> dummyInputVariableOps;
    for (auto inputType : entryFuncOp.getFunctionType().getInputs()) {
      auto dummyVar = createDummyInputVariableOp(loc, inputType, moduleBuilder);
      if (!dummyVar) return failure();
      dummyInputVariableOps.push_back(dummyVar);
    }

    // Create a `() -> ()` entry point op the benchmark tool can run.
    std::string funcName = std::string(entryFuncOp.getName()) + "_benchmark";
    auto funcOp = moduleBuilder.create<mlir::func::FuncOp>(
        loc, funcName, moduleBuilder.getFunctionType({}, {}));
    funcOp.setPublic();
    funcOp->setAttr("iree.abi.stub", moduleBuilder.getUnitAttr());
    SmallVector<NamedAttribute> reflectionAttrs = {
        moduleBuilder.getNamedAttr("benchmark",
                                   moduleBuilder.getStringAttr("entry")),
    };
    funcOp->setAttr("iree.reflection",
                    moduleBuilder.getDictionaryAttr(reflectionAttrs));

    // Call the existing function with the dummy globals as arguments.
    Block* block = funcOp.addEntryBlock();
    auto blockBuilder = OpBuilder::atBlockBegin(block);
    SmallVector<Value, 4> args;
    args.reserve(dummyInputVariableOps.size());
    for (auto dummyVar : dummyInputVariableOps) {
      args.push_back(
          blockBuilder.createOrFold<IREE::Util::GlobalLoadOp>(loc, dummyVar));
    }
    auto callOp =
        blockBuilder.create<mlir::func::CallOp>(loc, entryFuncOp, args);

    // Sink all results with do_not_optimize to ensure that DCE does not
    // remove the call.
    for (auto result : callOp.getResults()) {
      blockBuilder.create<IREE::Util::DoNotOptimizeOp>(loc, result);
    }
    blockBuilder.create<mlir::func::ReturnOp>(loc);

    // Ensure the original function is not exported and not inlined.
    entryFuncOp->setAttr("noinline", moduleBuilder.getUnitAttr());
    entryFuncOp->removeAttr("iree.reflection");
    entryFuncOp.setPrivate();
    return success();
  }

  // Suffix counter for dummy input global names; it is per pass instance and
  // never reset.
  int uniqueId = 0;
};
std::unique_ptr<OperationPass<mlir::ModuleOp>>
createExportBenchmarkFuncsPass() {
return std::make_unique<ExportBenchmarkFuncsPass>();
}
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir