Add a translation to export all dispatch functions. (#3431)

Fixes https://github.com/google/iree/issues/3384
diff --git a/iree/compiler/Dialect/Flow/Transforms/BUILD b/iree/compiler/Dialect/Flow/Transforms/BUILD
index c21ed77..6f3ea17 100644
--- a/iree/compiler/Dialect/Flow/Transforms/BUILD
+++ b/iree/compiler/Dialect/Flow/Transforms/BUILD
@@ -21,6 +21,7 @@
 cc_library(
     name = "Transforms",
     srcs = [
+        "CreateFuncsToInvokeExecOps.cpp",
         "DispatchConfig.cpp",
         "DispatchabilityAnalysis.cpp",
         "FlattenTuplesInCFG.cpp",
diff --git a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index d8c8698..efa46d6 100644
--- a/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -21,6 +21,7 @@
     "DispatchConfig.h"
     "Passes.h"
   SRCS
+    "CreateFuncsToInvokeExecOps.cpp"
     "DispatchConfig.cpp"
     "DispatchabilityAnalysis.cpp"
     "FlattenTuplesInCFG.cpp"
diff --git a/iree/compiler/Dialect/Flow/Transforms/CreateFuncsToInvokeExecOps.cpp b/iree/compiler/Dialect/Flow/Transforms/CreateFuncsToInvokeExecOps.cpp
new file mode 100644
index 0000000..6e5c258
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/CreateFuncsToInvokeExecOps.cpp
@@ -0,0 +1,81 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace Flow {
+
+// Walks through all the execuatable ops and creates Funcs to invoke them. The
+// input are provided using constants.
+class CreateFuncsToInvokeExecOpsPass
+    : public PassWrapper<CreateFuncsToInvokeExecOpsPass,
+                         OperationPass<ModuleOp>> {
+ public:
+  CreateFuncsToInvokeExecOpsPass() = default;
+
+  void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
+    auto builder = OpBuilder::atBlockBegin(moduleOp.getBody());
+    Location loc = moduleOp.getLoc();
+    auto execOps = moduleOp.getOps<IREE::Flow::ExecutableOp>();
+    for (auto execOp : execOps) {
+      for (auto& op : execOp.getBlock()) {
+        if (auto dispatchEntryOp = dyn_cast<IREE::Flow::DispatchEntryOp>(op)) {
+          auto execFuncOp = execOp.getInnerModule().lookupSymbol<FuncOp>(
+              dispatchEntryOp.function_ref());
+          std::string funcName = std::string(execFuncOp.getName()) + "_entry";
+          auto funcType =
+              builder.getFunctionType({}, execFuncOp.getType().getResults());
+          auto funcOp =
+              builder.create<FuncOp>(moduleOp.getLoc(), funcName, funcType);
+          funcOp.setAttr("iree.module.export", UnitAttr::get(&getContext()));
+          Block* block = funcOp.addEntryBlock();
+          auto blockBuilder = OpBuilder(block, block->begin());
+          SmallVector<Value, 4> args;
+          for (auto inputType : execFuncOp.getType().getInputs()) {
+            // TODO(hanchung): Use non-zero or random values as inputs.
+            auto attr = blockBuilder.getZeroAttr(inputType);
+            auto cst = blockBuilder.create<ConstantOp>(moduleOp.getLoc(),
+                                                       inputType, attr);
+            args.push_back(cst);
+          }
+          // TODO(hanchung): Use a real workload instead? We can probably
+          // calculate the workload from the results.
+          auto dummyWorkload = blockBuilder.create<ConstantIndexOp>(loc, 0);
+          auto dispatchOp = blockBuilder.create<DispatchOp>(
+              loc, dispatchEntryOp, dummyWorkload, funcType.getResults(), args);
+          blockBuilder.create<mlir::ReturnOp>(loc, dispatchOp.getResults());
+        }
+      }
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createCreateFuncsToInvokeExecOpsPass() {
+  return std::make_unique<CreateFuncsToInvokeExecOpsPass>();
+}
+
+}  // namespace Flow
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
index 5f2d2cd..98bb36c 100644
--- a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstantsPass.cpp
@@ -29,20 +29,16 @@
 namespace IREE {
 namespace Flow {
 
-// NOTE: a total guess :) this feels like about the most per-dispatch-buffer
-// data we'd want to embed in the command buffer.
-// TODO(benvanik): make a pass option so users can override.
-static constexpr size_t kMinLargeConstantSize = 256;
-
 // Returns true if |constantOp| is large enough to be considered for pooling.
 // Some constants are small enough that inlining them into the ringbuffer is
 // more efficient and fewer bindings.
-static bool isConstantLarge(ConstantOp constantOp) {
+static bool isConstantLarge(ConstantOp constantOp,
+                            size_t minLargeConstantSize) {
   auto type = constantOp.getType();
   if (auto shapedType = type.dyn_cast<RankedTensorType>()) {
     size_t unpackedByteLength =
         (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) / 8;
-    if (unpackedByteLength >= kMinLargeConstantSize) {
+    if (unpackedByteLength >= minLargeConstantSize) {
       return true;
     }
   }
@@ -52,12 +48,13 @@
 // Returns a list of all large constants in the module.
 // Only walks top-level functions and ops to avoid pulling constants out of
 // executables.
-static std::vector<ConstantOp> findLargeConstantsInModule(ModuleOp moduleOp) {
+static std::vector<ConstantOp> findLargeConstantsInModule(
+    ModuleOp moduleOp, size_t minLargeConstantSize) {
   std::vector<ConstantOp> largeConstantOps;
   for (auto funcOp : moduleOp.getOps<FuncOp>()) {
     for (auto &block : funcOp.getBlocks()) {
       for (auto constantOp : block.getOps<ConstantOp>()) {
-        if (isConstantLarge(constantOp)) {
+        if (isConstantLarge(constantOp, minLargeConstantSize)) {
           largeConstantOps.push_back(constantOp);
         }
       }
@@ -69,6 +66,10 @@
 class OutlineLargeConstantsPass
     : public PassWrapper<OutlineLargeConstantsPass, OperationPass<ModuleOp>> {
  public:
+  OutlineLargeConstantsPass() = default;
+  OutlineLargeConstantsPass(size_t minLargeConstantSize)
+      : minLargeConstantSize(minLargeConstantSize){};
+
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<IREE::Flow::FlowDialect>();
   }
@@ -84,7 +85,8 @@
     // Create all top-level flow.variables from large constants in the module.
     OpBuilder moduleBuilder(&moduleOp.getBody()->front());
     std::vector<std::pair<ConstantOp, IREE::Flow::VariableOp>> replacements;
-    for (auto &largeConstantOp : findLargeConstantsInModule(moduleOp)) {
+    for (auto &largeConstantOp :
+         findLargeConstantsInModule(moduleOp, minLargeConstantSize)) {
       std::string name;
       do {
         name = baseName + std::to_string(uniqueId++);
@@ -115,15 +117,23 @@
       constantOp.erase();
     }
   }
+
+ private:
+  size_t minLargeConstantSize;
 };
 
-std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass() {
-  return std::make_unique<OutlineLargeConstantsPass>();  // NOLINT
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass(
+    size_t minLargeConstantSize) {
+  return std::make_unique<OutlineLargeConstantsPass>(
+      minLargeConstantSize);  // NOLINT
 }
 
 static PassRegistration<OutlineLargeConstantsPass> pass(
     "iree-flow-outline-large-constants",
-    "Outlines large tensor constants into flow.variables at the module level.");
+    "Outlines large tensor constants into flow.variables at the module level.",
+    [] {
+      return std::make_unique<OutlineLargeConstantsPass>(kMinLargeConstantSize);
+    });
 
 }  // namespace Flow
 }  // namespace IREE
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index c2f821f..384f102 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -214,6 +214,28 @@
       });
 }
 
+void buildExportDispatchesTransformPassPipeline(OpPassManager &passManager) {
+  passManager.addPass(IREE::Flow::createCreateFuncsToInvokeExecOpsPass());
+  // Move all the constants to flow.variables.
+  passManager.addPass(createOutlineLargeConstantsPass(
+      /*minLargeConstantSize=*/0));
+  passManager.addPass(IREE::Flow::createMaterializeExportedReflection());
+  passManager.addPass(IREE::Flow::createMergeExportedReflection());
+  passManager.addPass(IREE::Flow::createFormStreamsPass());
+  passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
+  passManager.addNestedPass<FuncOp>(createCSEPass());
+  passManager.addPass(createSymbolDCEPass());
+}
+
+void registerExportDispatchesTransformPassPipeline() {
+  PassPipelineRegistration<> transformPassPipeline(
+      "iree-flow-export-dispatches",
+      "Runs the pipeline to export dispatch functions",
+      [](OpPassManager &passManager) {
+        buildExportDispatchesTransformPassPipeline(passManager);
+      });
+}
+
 }  // namespace Flow
 }  // namespace IREE
 }  // namespace iree_compiler
diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 0e432cb..edb511a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -46,6 +46,15 @@
 
 void registerFlowTransformPassPipeline();
 
+// Adds a set of passes to the given pass manager that run the flow transforms
+// to export dispatch functions.
+//
+// The expected usage is to add passes right after
+// buildFlowTransformPassPipieline.
+void buildExportDispatchesTransformPassPipeline(OpPassManager &passManager);
+
+void registerExportDispatchesTransformPassPipeline();
+
 //===----------------------------------------------------------------------===//
 // Input canonicalization and legalization
 //===----------------------------------------------------------------------===//
@@ -116,6 +125,9 @@
 // Outlines dispatch regions into executables.
 std::unique_ptr<OperationPass<ModuleOp>> createOutlineDispatchRegionsPass();
 
+// Exports all the dispatch functions to the module.
+std::unique_ptr<OperationPass<ModuleOp>> createCreateFuncsToInvokeExecOpsPass();
+
 //===----------------------------------------------------------------------===//
 // Optimizations
 //===----------------------------------------------------------------------===//
@@ -124,7 +136,12 @@
 // shaped, adjusting types, etc).
 
 // Outlines large tensor constants into flow.variables at the module level.
-std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass();
+//
+// NOTE: a total guess :) this feels like about the most per-dispatch-buffer
+// data we'd want to embed in the command buffer.
+static constexpr size_t kMinLargeConstantSize = 256;
+std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass(
+    size_t minLargeConstantSize = kMinLargeConstantSize);
 
 //===----------------------------------------------------------------------===//
 // Stream Formation and Folding
@@ -148,6 +165,7 @@
 
 inline void registerFlowPasses() {
   registerFlowTransformPassPipeline();
+  registerExportDispatchesTransformPassPipeline();
   createFlattenTuplesInCFGPass();
   createLegalizeInputTypesPass();
   createHLOPreprocessingPass();
@@ -162,6 +180,7 @@
   createRematerializeDispatchConstantsPass();
   createOutlineDispatchRegionsPass();
   createOutlineLargeConstantsPass();
+  createCreateFuncsToInvokeExecOpsPass();
   createFormStreamsPass();
   createHoistUnstreamableOpsPass();
 }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/create_funcs_to_invoke_exec_ops.mlir b/iree/compiler/Dialect/Flow/Transforms/test/create_funcs_to_invoke_exec_ops.mlir
new file mode 100644
index 0000000..4f297c6
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/create_funcs_to_invoke_exec_ops.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-opt -iree-flow-transformation-pipeline -iree-flow-export-dispatches %s | IreeFileCheck %s
+
+module {
+  func @two_dispatch(%arg0: tensor<5x3xf32>, %arg1: tensor<3x5xf32>) -> (tensor<5x5xf32>, tensor<3x5xf32>) attributes { iree.module.export } {
+    %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
+    %1 = "mhlo.dot"(%arg1, %0) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32>
+    return %0, %1 : tensor<5x5xf32>, tensor<3x5xf32>
+  }
+}
+// CHECK: func @two_dispatch_ex_dispatch_0_entry
+// CHECK: %{{.+}} = flow.variable.load {{.*}} : tensor<5x3xf32>
+// CHECK: %{{.+}} = flow.variable.load {{.*}} : tensor<3x5xf32>
+// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<5x5xf32> {
+// CHECK:   %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_0::@two_dispatch_ex_dispatch_0[%{{.+}} : index](%{{.+}}, %{{.+}}) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32>
+// CHECK:   flow.return %[[DISPATCH_RES]] : tensor<5x5xf32>
+// CHECK: return %[[RES]] : tensor<5x5xf32>
+//
+// CHECK: func @two_dispatch_ex_dispatch_1_entry
+// CHECK: %[[ARG0:.+]] = flow.variable.load {{.*}} : tensor<3x5xf32>
+// CHECK: %[[ARG1:.+]] = flow.variable.load {{.*}} : tensor<5x5xf32>
+// CHECK: %[[RES:.+]] = flow.ex.stream.fragment({{.+}}) -> tensor<3x5xf32>
+// CHECK:   %[[DISPATCH_RES:.+]] = flow.dispatch @two_dispatch_ex_dispatch_1::@two_dispatch_ex_dispatch_1[%{{.+}} : index](%{{.+}}, %{{.+}}) : (tensor<3x5xf32>, tensor<5x5xf32>) -> tensor<3x5xf32>
+// CHECK:   flow.return %[[DISPATCH_RES]] : tensor<3x5xf32>
+// CHECK: return %[[RES]] : tensor<3x5xf32>
diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp
index de6b124..353e4ed 100644
--- a/iree/compiler/Translation/IREEVM.cpp
+++ b/iree/compiler/Translation/IREEVM.cpp
@@ -81,13 +81,17 @@
 
 static LogicalResult translateFromMLIRToVM(
     ModuleOp moduleOp, IREE::HAL::TargetOptions executableOptions,
-    IREE::VM::TargetOptions targetOptions) {
+    IREE::VM::TargetOptions targetOptions,
+    bool addExportDispatchesPipeline = false) {
   // Convert from our source to a vm.module in canonical form.
   // After this completes we have a non-bytecode-specific vm.module that we
   // could lower to other forms (LLVM IR, C, etc).
   PassManager passManager(moduleOp.getContext());
   mlir::applyPassManagerCLOptions(passManager);
   IREE::Flow::buildFlowTransformPassPipeline(passManager);
+  if (addExportDispatchesPipeline) {
+    IREE::Flow::buildExportDispatchesTransformPassPipeline(passManager);
+  }
   IREE::HAL::buildHALTransformPassPipeline(passManager, executableOptions);
   IREE::VM::buildVMTransformPassPipeline(passManager, targetOptions);
   passManager.addPass(mlir::iree_compiler::IREE::createDropCompilerHintsPass());
@@ -99,12 +103,13 @@
 }
 
 LogicalResult translateFromMLIRToVMBytecodeModule(
-    ModuleOp moduleOp, IREE::HAL::TargetOptions executableOptions,
+    ModuleOp moduleOp, bool addExportDispatchesPipeline,
+    IREE::HAL::TargetOptions executableOptions,
     IREE::VM::TargetOptions targetOptions,
     IREE::VM::BytecodeTargetOptions bytecodeOptions,
     llvm::raw_ostream &output) {
-  auto result =
-      translateFromMLIRToVM(moduleOp, executableOptions, targetOptions);
+  auto result = translateFromMLIRToVM(
+      moduleOp, executableOptions, targetOptions, addExportDispatchesPipeline);
   if (failed(result)) {
     return result;
   }
@@ -119,9 +124,20 @@
   auto halTargetOptions = IREE::HAL::getTargetOptionsFromFlags();
   auto vmTargetOptions = IREE::VM::getTargetOptionsFromFlags();
   auto bytecodeTargetOptions = IREE::VM::getBytecodeTargetOptionsFromFlags();
-  return translateFromMLIRToVMBytecodeModule(moduleOp, halTargetOptions,
-                                             vmTargetOptions,
-                                             bytecodeTargetOptions, output);
+  return translateFromMLIRToVMBytecodeModule(
+      moduleOp, /*addExportDispatchesPipeline=*/false, halTargetOptions,
+      vmTargetOptions, bytecodeTargetOptions, output);
+}
+
+static LogicalResult translateFromMLIRToBenchmarkVMBytecodeModuleWithFlags(
+    ModuleOp moduleOp, llvm::raw_ostream &output) {
+  mlir::registerPassManagerCLOptions();
+  auto halTargetOptions = IREE::HAL::getTargetOptionsFromFlags();
+  auto vmTargetOptions = IREE::VM::getTargetOptionsFromFlags();
+  auto bytecodeTargetOptions = IREE::VM::getBytecodeTargetOptionsFromFlags();
+  return translateFromMLIRToVMBytecodeModule(
+      moduleOp, /*addExportDispatchesPipeline=*/true, halTargetOptions,
+      vmTargetOptions, bytecodeTargetOptions, output);
 }
 
 #ifdef IREE_HAVE_EMITC_DIALECT
@@ -153,6 +169,10 @@
       "iree-mlir-to-vm-bytecode-module",
       translateFromMLIRToVMBytecodeModuleWithFlags);
 
+  TranslateFromMLIRRegistration toBenchmarkVMBytecodeModuleWithFlags(
+      "iree-mlir-to-benchmark-vm-bytecode-module",
+      translateFromMLIRToBenchmarkVMBytecodeModuleWithFlags);
+
 #ifdef IREE_HAVE_EMITC_DIALECT
   TranslateFromMLIRRegistration toVMCModuleWithFlags(
       "iree-mlir-to-vm-c-module", translateFromMLIRToVMCModuleWithFlags);
diff --git a/iree/compiler/Translation/IREEVM.h b/iree/compiler/Translation/IREEVM.h
index 5179d10..2bd624e 100644
--- a/iree/compiler/Translation/IREEVM.h
+++ b/iree/compiler/Translation/IREEVM.h
@@ -52,7 +52,8 @@
 //
 // Exposed via the --iree-mlir-to-vm-bytecode-module translation.
 LogicalResult translateFromMLIRToVMBytecodeModule(
-    ModuleOp moduleOp, IREE::HAL::TargetOptions executableOptions,
+    ModuleOp moduleOp, bool addExportDispatchesPipeline,
+    IREE::HAL::TargetOptions executableOptions,
     IREE::VM::TargetOptions targetOptions,
     IREE::VM::BytecodeTargetOptions bytecodeOptions, llvm::raw_ostream &output);