Add a translation to export all dispatch functions. (#3431)
Fixes https://github.com/google/iree/issues/3384
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index c21ed77..6f3ea17 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -21,6 +21,7 @@
cc_library(
name = "Transforms",
srcs = [
+ "CreateFuncsToInvokeExecOps.cpp",
"DispatchConfig.cpp",
"DispatchabilityAnalysis.cpp",
"FlattenTuplesInCFG.cpp",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index d8c8698..efa46d6 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -21,6 +21,7 @@
"DispatchConfig.h"
"Passes.h"
SRCS
+ "CreateFuncsToInvokeExecOps.cpp"
"DispatchConfig.cpp"
"DispatchabilityAnalysis.cpp"
"FlattenTuplesInCFG.cpp"
diff --git a/iree/compiler/Dialect/Flow/Transforms/CreateFuncsToInvokeExecOps.cpp b/iree/compiler/Dialect/Flow/Transforms/CreateFuncsToInvokeExecOps.cpp
new file mode 100644
index 0000000..6e5c258
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/CreateFuncsToInvokeExecOps.cpp
@@ -0,0 +1,81 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+// Walks through all the executable ops in the module and, for every dispatch
+// entry they contain, creates a module-level func that invokes it.
+//
+// Each generated func is named "<dispatch function name>_entry", takes no
+// arguments, returns the same result types as the dispatched function and is
+// tagged with the "iree.module.export" attribute. The inputs of the dispatch
+// are provided using zero-valued constants.
+class CreateFuncsToInvokeExecOpsPass
+    : public PassWrapper<CreateFuncsToInvokeExecOpsPass,
+                         OperationPass<ModuleOp>> {
+ public:
+  CreateFuncsToInvokeExecOpsPass() = default;
+
+  void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
+    // New funcs are inserted at the very beginning of the module body.
+    auto builder = OpBuilder::atBlockBegin(moduleOp.getBody());
+    Location loc = moduleOp.getLoc();
+    auto execOps = moduleOp.getOps<IREE::Flow::ExecutableOp>();
+    for (auto execOp : execOps) {
+      for (auto& op : execOp.getBlock()) {
+        auto dispatchEntryOp = dyn_cast<IREE::Flow::DispatchEntryOp>(op);
+        if (!dispatchEntryOp) continue;
+
+        auto execFuncOp = execOp.getInnerModule().lookupSymbol<FuncOp>(
+            dispatchEntryOp.function_ref());
+        // lookupSymbol returns a null op when the symbol is missing; fail the
+        // pass with a diagnostic instead of dereferencing it.
+        if (!execFuncOp) {
+          dispatchEntryOp.emitError()
+              << "unable to find the function referenced by the dispatch "
+                 "entry in the inner module of the executable";
+          return signalPassFailure();
+        }
+
+        std::string funcName = std::string(execFuncOp.getName()) + "_entry";
+        auto funcType =
+            builder.getFunctionType({}, execFuncOp.getType().getResults());
+        auto funcOp = builder.create<FuncOp>(loc, funcName, funcType);
+        funcOp.setAttr("iree.module.export", UnitAttr::get(&getContext()));
+        Block* block = funcOp.addEntryBlock();
+        auto blockBuilder = OpBuilder(block, block->begin());
+
+        SmallVector<Value, 4> args;
+        for (auto inputType : execFuncOp.getType().getInputs()) {
+          // TODO(hanchung): Use non-zero or random values as inputs.
+          auto attr = blockBuilder.getZeroAttr(inputType);
+          // getZeroAttr yields a null attribute for types it cannot handle;
+          // a constant cannot be built from it.
+          if (!attr) {
+            execFuncOp.emitError()
+                << "unable to create a zero constant of type " << inputType;
+            return signalPassFailure();
+          }
+          auto cst = blockBuilder.create<ConstantOp>(loc, inputType, attr);
+          args.push_back(cst);
+        }
+
+        // TODO(hanchung): Use a real workload instead? We can probably
+        // calculate the workload from the results.
+        auto dummyWorkload = blockBuilder.create<ConstantIndexOp>(loc, 0);
+        auto dispatchOp = blockBuilder.create<DispatchOp>(
+            loc, dispatchEntryOp, dummyWorkload, funcType.getResults(), args);
+        blockBuilder.create<mlir::ReturnOp>(loc, dispatchOp.getResults());
+      }
+    }
+  }
+};
+
+// Factory for the pass that creates exported funcs invoking each dispatch
+// entry of the executables in a module.
+std::unique_ptr<OperationPass<ModuleOp>>
+createCreateFuncsToInvokeExecOpsPass() {
+  std::unique_ptr<OperationPass<ModuleOp>> pass =
+      std::make_unique<CreateFuncsToInvokeExecOpsPass>();
+  return pass;
+}
+
+} // namespace Flow
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
index 5f2d2cd..98bb36c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
@@ -29,20 +29,16 @@
namespace IREE {
namespace Flow {
-// NOTE: a total guess :) this feels like about the most per-dispatch-buffer
-// data we'd want to embed in the command buffer.
-// TODO(benvanik): make a pass option so users can override.
-static constexpr size_t kMinLargeConstantSize = 256;
-
// Returns true if |constantOp| is large enough to be considered for pooling.
// Some constants are small enough that inlining them into the ringbuffer is
// more efficient and fewer bindings.
-static bool isConstantLarge(ConstantOp constantOp) {
+static bool isConstantLarge(ConstantOp constantOp,
+ size_t minLargeConstantSize) {
auto type = constantOp.getType();
if (auto shapedType = type.dyn_cast<RankedTensorType>()) {
size_t unpackedByteLength =
(shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) / 8;
- if (unpackedByteLength >= kMinLargeConstantSize) {
+ if (unpackedByteLength >= minLargeConstantSize) {
return true;
}
}
@@ -52,12 +48,13 @@
// Returns a list of all large constants in the module.
// Only walks top-level functions and ops to avoid pulling constants out of
// executables.
-static std::vector<ConstantOp> findLargeConstantsInModule(ModuleOp moduleOp) {
+static std::vector<ConstantOp> findLargeConstantsInModule(
+ ModuleOp moduleOp, size_t minLargeConstantSize) {
std::vector<ConstantOp> largeConstantOps;
for (auto funcOp : moduleOp.getOps<FuncOp>()) {
for (auto &block : funcOp.getBlocks()) {
for (auto constantOp : block.getOps<ConstantOp>()) {
- if (isConstantLarge(constantOp)) {
+ if (isConstantLarge(constantOp, minLargeConstantSize)) {
largeConstantOps.push_back(constantOp);
}
}
@@ -69,6 +66,10 @@
class OutlineLargeConstantsPass
: public PassWrapper<OutlineLargeConstantsPass, OperationPass<ModuleOp>> {
public:
+  // Default construction uses the kMinLargeConstantSize threshold so that the
+  // |minLargeConstantSize| member is never left uninitialized.
+  OutlineLargeConstantsPass()
+      : OutlineLargeConstantsPass(kMinLargeConstantSize) {}
+  // |minLargeConstantSize| is the byte length at or above which a constant is
+  // treated as large by isConstantLarge.
+  explicit OutlineLargeConstantsPass(size_t minLargeConstantSize)
+      : minLargeConstantSize(minLargeConstantSize) {}
+
void getDependentDialects(DialectRegistry ®istry) const override {
registry.insert<IREE::Flow::FlowDialect>();
}
@@ -84,7 +85,8 @@
// Create all top-level flow.variables from large constants in the module.
OpBuilder moduleBuilder(&moduleOp.getBody()->front());
std::vector<std::pair<ConstantOp, IREE::Flow::VariableOp>> replacements;
- for (auto &largeConstantOp : findLargeConstantsInModule(moduleOp)) {
+ for (auto &largeConstantOp :
+ findLargeConstantsInModule(moduleOp, minLargeConstantSize)) {
std::string name;
do {
name = baseName + std::to_string(uniqueId++);
@@ -115,15 +117,23 @@
constantOp.erase();
}
}
+
+ private:
+ size_t minLargeConstantSize;
};
-std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass() {
- return std::make_unique<OutlineLargeConstantsPass>(); // NOLINT
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass(
+ size_t minLargeConstantSize) {
+ return std::make_unique<OutlineLargeConstantsPass>(
+ minLargeConstantSize); // NOLINT
}
static PassRegistration<OutlineLargeConstantsPass> pass(
"iree-flow-outline-large-constants",
- "Outlines large tensor constants into flow.variables at the module level.");
+ "Outlines large tensor constants into flow.variables at the module level.",
+ [] {
+ return std::make_unique<OutlineLargeConstantsPass>(kMinLargeConstantSize);
+ });
} // namespace Flow
} // namespace IREE
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index c2f821f..384f102 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -214,6 +214,28 @@
});
}
+// Appends to |passManager| the passes that create exported funcs for the
+// dispatch functions and then re-run reflection, stream formation and cleanup.
+void buildExportDispatchesTransformPassPipeline(OpPassManager &passManager) {
+  // A size threshold of zero is used to move all the constants to
+  // flow.variables.
+  constexpr size_t kOutlineAllConstantsThreshold = 0;
+
+  // Create the exported funcs and outline the constants feeding them.
+  passManager.addPass(IREE::Flow::createCreateFuncsToInvokeExecOpsPass());
+  passManager.addPass(
+      createOutlineLargeConstantsPass(kOutlineAllConstantsThreshold));
+
+  // Reflection metadata for the exported funcs.
+  passManager.addPass(IREE::Flow::createMaterializeExportedReflection());
+  passManager.addPass(IREE::Flow::createMergeExportedReflection());
+
+  // Stream formation followed by per-function and module-level cleanup.
+  passManager.addPass(IREE::Flow::createFormStreamsPass());
+  passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+  passManager.addNestedPass<FuncOp>(createCSEPass());
+  passManager.addPass(createSymbolDCEPass());
+}
+
+// Registers the export-dispatches pipeline under the
+// "iree-flow-export-dispatches" flag.
+void registerExportDispatchesTransformPassPipeline() {
+  auto pipelineBuilder = [](OpPassManager &pm) {
+    buildExportDispatchesTransformPassPipeline(pm);
+  };
+  PassPipelineRegistration<> transformPassPipeline(
+      "iree-flow-export-dispatches",
+      "Runs the pipeline to export dispatch functions", pipelineBuilder);
+}
+
} // namespace Flow
} // namespace IREE
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 0e432cb..edb511a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -46,6 +46,15 @@
void registerFlowTransformPassPipeline();
+// Adds a set of passes to the given pass manager that run the flow transforms
+// to export dispatch functions.
+//
+// The expected usage is to add these passes right after
+// buildFlowTransformPassPipeline.
+void buildExportDispatchesTransformPassPipeline(OpPassManager &passManager);
+
+void registerExportDispatchesTransformPassPipeline();
+
//===----------------------------------------------------------------------===//
// Input canonicalization and legalization
//===----------------------------------------------------------------------===//
@@ -116,6 +125,9 @@
// Outlines dispatch regions into executables.
std::unique_ptr<OperationPass<ModuleOp>> createOutlineDispatchRegionsPass();
+// Exports all the dispatch functions to the module.
+std::unique_ptr<OperationPass<ModuleOp>> createCreateFuncsToInvokeExecOpsPass();
+
//===----------------------------------------------------------------------===//
// Optimizations
//===----------------------------------------------------------------------===//
@@ -124,7 +136,12 @@
// shaped, adjusting types, etc).
// Outlines large tensor constants into flow.variables at the module level.
-std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass();
+//
+// NOTE: a total guess :) this feels like about the most per-dispatch-buffer
+// data we'd want to embed in the command buffer.
+static constexpr size_t kMinLargeConstantSize = 256;
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass(
+ size_t minLargeConstantSize = kMinLargeConstantSize);
//===----------------------------------------------------------------------===//
// Stream Formation and Folding
@@ -148,6 +165,7 @@
inline void registerFlowPasses() {
registerFlowTransformPassPipeline();
+ registerExportDispatchesTransformPassPipeline();
createFlattenTuplesInCFGPass();
createLegalizeInputTypesPass();
createHLOPreprocessingPass();
@@ -162,6 +180,7 @@
createRematerializeDispatchConstantsPass();
createOutlineDispatchRegionsPass();
createOutlineLargeConstantsPass();
+ createCreateFuncsToInvokeExecOpsPass();
createFormStreamsPass();
createHoistUnstreamableOpsPass();
}
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/create_funcs_to_invoke_exec_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/create_funcs_to_invoke_exec_ops.mlir
new file mode 100644
index 0000000..4f297c6
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/create_funcs_to_invoke_exec_ops.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-opt -iree-flow-transformation-pipeline -iree-flow-export-dispatches %s | IreeFileCheck %s
+
+module {
+ func @two_dispatch(%arg0: tensor<5x3xf32>, %arg1: tensor<3x5xf32>) -> (tensor<5x5xf32>, tensor<3x5xf32>) attributes { iree.module.export } {
+ %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
+ %1 = "mhlo.dot"(%arg1, %0) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32>
+ return %0, %1 : tensor<5x5xf32>, tensor<3x5xf32>
+ }
+}
+// CHECK: func @two_dispatch_ex_dispatch_0_entry
+// CHECK: %{{.+}} = flow.variable.load {{.*}} : tensor<5x3xf32>
+// CHECK: %{{.+}} = flow.variable.load {{.*}} : tensor<3x5xf32>
+// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<5x5xf32> {
+// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}} : index](%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
+// CHECK: flow.return %[[DISPATCH_RES]] : tensor<5x5xf32>
+// CHECK: return %[[RES]] : tensor<5x5xf32>
+//
+// CHECK: func @two_dispatch_ex_dispatch_1_entry
+// CHECK: %[[ARG0:.+]] = flow.variable.load {{.*}} : tensor<3x5xf32>
+// CHECK: %[[ARG1:.+]] = flow.variable.load {{.*}} : tensor<5x5xf32>
+// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<3x5xf32>
+// CHECK: %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}} : index](%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32>
+// CHECK: flow.return %[[DISPATCH_RES]] : tensor<3x5xf32>
+// CHECK: return %[[RES]] : tensor<3x5xf32>
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index de6b124..353e4ed 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -81,13 +81,17 @@
static LogicalResult translateFromMLIRToVM(
ModuleOp moduleOp, IREE::HAL::TargetOptions executableOptions,
- IREE::VM::TargetOptions targetOptions) {
+ IREE::VM::TargetOptions targetOptions,
+ bool addExportDispatchesPipeline = false) {
// Convert from our source to a vm.module in canonical form.
// After this completes we have a non-bytecode-specific vm.module that we
// could lower to other forms (LLVM IR, C, etc).
PassManager passManager(moduleOp.getContext());
mlir::applyPassManagerCLOptions(passManager);
IREE::Flow::buildFlowTransformPassPipeline(passManager);
+ if (addExportDispatchesPipeline) {
+ IREE::Flow::buildExportDispatchesTransformPassPipeline(passManager);
+ }
IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions);
IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
passManager.addPass(mlir::iree_compiler::IREE::createDropCompilerHintsPass());
@@ -99,12 +103,13 @@
}
LogicalResult translateFromMLIRToVMBytecodeModule(
- ModuleOp moduleOp, IREE::HAL::TargetOptions executableOptions,
+ ModuleOp moduleOp, bool addExportDispatchesPipeline,
+ IREE::HAL::TargetOptions executableOptions,
IREE::VM::TargetOptions targetOptions,
IREE::VM::BytecodeTargetOptions bytecodeOptions,
llvm::raw_ostream &output) {
- auto result =
- translateFromMLIRToVM(moduleOp, executableOptions, targetOptions);
+ auto result = translateFromMLIRToVM(
+ moduleOp, executableOptions, targetOptions, addExportDispatchesPipeline);
if (failed(result)) {
return result;
}
@@ -119,9 +124,20 @@
auto halTargetOptions = IREE::HAL::getTargetOptionsFromFlags();
auto vmTargetOptions = IREE::VM::getTargetOptionsFromFlags();
auto bytecodeTargetOptions = IREE::VM::getBytecodeTargetOptionsFromFlags();
- return translateFromMLIRToVMBytecodeModule(moduleOp, halTargetOptions,
- vmTargetOptions,
- bytecodeTargetOptions, output);
+ return translateFromMLIRToVMBytecodeModule(
+ moduleOp, /*addExportDispatchesPipeline=*/false, halTargetOptions,
+ vmTargetOptions, bytecodeTargetOptions, output);
+}
+
+// Translates |moduleOp| to a VM bytecode module written to |output|, with the
+// export-dispatches pipeline enabled. All target options are read from flags.
+static LogicalResult translateFromMLIRToBenchmarkVMBytecodeModuleWithFlags(
+    ModuleOp moduleOp, llvm::raw_ostream &output) {
+  mlir::registerPassManagerCLOptions();
+
+  // Options are gathered in the same order as the regular translation.
+  auto halOptions = IREE::HAL::getTargetOptionsFromFlags();
+  auto vmOptions = IREE::VM::getTargetOptionsFromFlags();
+  auto bytecodeOptions = IREE::VM::getBytecodeTargetOptionsFromFlags();
+
+  constexpr bool kAddExportDispatchesPipeline = true;
+  return translateFromMLIRToVMBytecodeModule(
+      moduleOp, kAddExportDispatchesPipeline, halOptions, vmOptions,
+      bytecodeOptions, output);
+}
#ifdef IREE_HAVE_EMITC_DIALECT
@@ -153,6 +169,10 @@
"iree-mlir-to-vm-bytecode-module",
translateFromMLIRToVMBytecodeModuleWithFlags);
+ TranslateFromMLIRRegistration toBenchmarkVMBytecodeModuleWithFlags(
+ "iree-mlir-to-benchmark-vm-bytecode-module",
+ translateFromMLIRToBenchmarkVMBytecodeModuleWithFlags);
+
#ifdef IREE_HAVE_EMITC_DIALECT
TranslateFromMLIRRegistration toVMCModuleWithFlags(
"iree-mlir-to-vm-c-module", translateFromMLIRToVMCModuleWithFlags);
diff --git a/iree/compiler/Translation/IREEVM.h b/iree/compiler/Translation/IREEVM.h
index 5179d10..2bd624e 100644
--- a/iree/compiler/Translation/IREEVM.h
+++ b/iree/compiler/Translation/IREEVM.h
@@ -52,7 +52,8 @@
//
// Exposed via the --iree-mlir-to-vm-bytecode-module translation.
LogicalResult translateFromMLIRToVMBytecodeModule(
- ModuleOp moduleOp, IREE::HAL::TargetOptions executableOptions,
+ ModuleOp moduleOp, bool addExportDispatchesPipeline,
+ IREE::HAL::TargetOptions executableOptions,
IREE::VM::TargetOptions targetOptions,
IREE::VM::BytecodeTargetOptions bytecodeOptions, llvm::raw_ostream &output);