Add support to benchmark in batch mode. (#5012)

If `--iree-hal-benchmark-dispatch-repeat-count` is set during a translation, all the
hal.command_buffer.dispatch ops will be run
`iree-hal-benchmark-dispatch-repeat-count` times. In this context, you have to pass
the same number on `--batch_size` to iree-benchmark-module.

This is helpful to amortize overhead when the execution time is small.

Fixes https://github.com/google/iree/issues/3995
diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD
index 6624e05..007ae40 100644
--- a/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -21,6 +21,7 @@
 cc_library(
     name = "Transforms",
     srcs = [
+        "BenchmarkBatchDispatches.cpp",
         "CSEVariableLoads.cpp",
         "ConvertToHAL.cpp",
         "IdentifyConstantPools.cpp",
diff --git a/iree/compiler/Dialect/HAL/Transforms/BenchmarkBatchDispatches.cpp b/iree/compiler/Dialect/HAL/Transforms/BenchmarkBatchDispatches.cpp
new file mode 100644
index 0000000..383f7bd
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/BenchmarkBatchDispatches.cpp
@@ -0,0 +1,71 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+namespace {
+
+// A pass converting the IREE flow dialect into the IREE HAL dialect.
+class BenchmarkBatchDispatchesPass
+    : public PassWrapper<BenchmarkBatchDispatchesPass, OperationPass<FuncOp>> {
+ public:
+  explicit BenchmarkBatchDispatchesPass(unsigned repeatCount)
+      : repeatCount_(repeatCount) {}
+
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<HALDialect, StandardOpsDialect>();
+  }
+
+  void runOnOperation() override {
+    FuncOp f = getOperation();
+    SmallVector<HAL::CommandBufferDispatchOp> ops;
+    f.walk([&](HAL::CommandBufferDispatchOp op) { ops.push_back(op); });
+
+    for (auto op : ops) {
+      OpBuilder builder(op);
+      for (unsigned i = 0; i < repeatCount_; ++i) {
+        builder.clone(*op.getOperation());
+      }
+      op.erase();
+    }
+  }
+
+ private:
+  unsigned repeatCount_;
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createBenchmarkBatchDispatchesPass(
+    unsigned repeatCount) {
+  return std::make_unique<BenchmarkBatchDispatchesPass>(repeatCount);
+}
+
+static PassRegistration<BenchmarkBatchDispatchesPass> pass(
+    "test-iree-hal-benchmark-batch-dispatches-2-times",
+    "Test pass used for benchmarking batch dispatches analysis",
+    [] { return std::make_unique<BenchmarkBatchDispatchesPass>(2); });
+
+}  // namespace HAL
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 80b4863..2a0d76b 100644
--- a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@
   HDRS
     "Passes.h"
   SRCS
+    "BenchmarkBatchDispatches.cpp"
     "CSEVariableLoads.cpp"
     "ConvertToHAL.cpp"
     "IdentifyConstantPools.cpp"
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index ce10e6b..29f64a4 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -41,6 +41,12 @@
       llvm::cl::init(true)};
 };
 
+static llvm::cl::opt<unsigned> benchmarkDispatchRepeatCount{
+    "iree-hal-benchmark-dispatch-repeat-count",
+    llvm::cl::desc(
+        "The number of times to repeat each hal.command_buffer.dispatch op."),
+    llvm::cl::init(1)};
+
 }  // namespace
 
 void buildHALTransformPassPipeline(OpPassManager &passManager,
@@ -110,6 +116,10 @@
   // Inline hal.device.switch ops and memoize their queries such that we can
   // better CSE/fold dispatch logic.
   passManager.addNestedPass<FuncOp>(createInlineDeviceSwitchesPass());
+  if (benchmarkDispatchRepeatCount != 1) {
+    passManager.addNestedPass<FuncOp>(
+        createBenchmarkBatchDispatchesPass(benchmarkDispatchRepeatCount));
+  }
   passManager.addPass(createLowerAffinePass());
   passManager.addPass(createMemoizeDeviceQueriesPass());
   passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h
index ef536c5..99c10f3 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -130,6 +130,11 @@
 // TODO(#1124): replace with memory side effects once supported upstream.
 std::unique_ptr<OperationPass<FuncOp>> createCSEVariableLoadsPass();
 
+// Repeats dispatches `iree-hal-repeat-dispatch-num` times, which is 1 by
+// default.
+std::unique_ptr<OperationPass<FuncOp>> createBenchmarkBatchDispatchesPass(
+    unsigned repeatCount);
+
 //===----------------------------------------------------------------------===//
 // Register all Passes
 //===----------------------------------------------------------------------===//
@@ -138,6 +143,7 @@
   registerHALTransformPassPipeline();
   auto executableOptions = getTargetOptionsFromFlags();
   createConvertToHALPass();
+  createBenchmarkBatchDispatchesPass(/*repeatCount=*/1);
   createInlineDeviceSwitchesPass();
   createMemoizeDeviceQueriesPass();
   createMaterializeInterfacesPass(executableOptions);
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir
new file mode 100644
index 0000000..45b479e
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-opt -split-input-file -test-iree-hal-benchmark-batch-dispatches-2-times %s | IreeFileCheck %s
+
+hal.variable @_executable_0 : !hal.executable
+func @multiple_reads_no_writes() {
+  %0 = hal.variable.load @_executable_0 : !hal.executable
+  %1 = hal.variable.load @_executable_0 : !hal.executable
+  %2 = hal.variable.load @_executable_0 : !hal.executable
+
+  %c1 = constant 1 : index
+  %dev = hal.ex.shared_device : !hal.device
+  %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer
+  hal.command_buffer.begin %cmd
+  hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+  hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+  hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+  hal.command_buffer.end %cmd
+  return
+}
+// CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index b567c1d..058dc2b 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -31,6 +31,13 @@
           "File containing the module to load that contains the entry "
           "function. Defaults to stdin.");
 
+// TODO(hanchung): Extract the batch size using
+// iree_vm_function_reflection_attr.
+ABSL_FLAG(
+    int, batch_size, 1,
+    "The number of batch size, which is expected to match "
+    "iree-hal-benchmark-dispatch-repeat-count when translating the module");
+
 ABSL_FLAG(std::string, entry_function, "",
           "Name of a function contained in the module specified by module_file "
           "to run. If this is not set, all the exported functions will be "
@@ -58,15 +65,16 @@
 namespace {
 
 static void BenchmarkFunction(
-    const std::string& benchmark_name, iree_vm_context_t* context,
-    iree_vm_function_t function, iree_vm_list_t* inputs,
+    const std::string& benchmark_name, int batch_size,
+    iree_vm_context_t* context, iree_vm_function_t function,
+    iree_vm_list_t* inputs,
     const std::vector<RawSignatureParser::Description>& output_descs,
     benchmark::State& state) {
   IREE_TRACE_SCOPE_DYNAMIC(benchmark_name.c_str());
   IREE_TRACE_FRAME_MARK();
 
   // Benchmarking loop.
-  for (auto _ : state) {
+  while (state.KeepRunningBatch(batch_size)) {
     IREE_TRACE_SCOPE0("BenchmarkIteration");
     IREE_TRACE_FRAME_MARK_NAMED("Iteration");
     vm::ref<iree_vm_list_t> outputs;
@@ -83,13 +91,14 @@
     iree_vm_function_t function, iree_vm_list_t* inputs,
     const std::vector<RawSignatureParser::Description>& output_descs) {
   auto benchmark_name = "BM_" + function_name;
-  benchmark::RegisterBenchmark(benchmark_name.c_str(),
-                               [benchmark_name, context, function, inputs,
-                                output_descs](benchmark::State& state) -> void {
-                                 BenchmarkFunction(benchmark_name, context,
-                                                   function, inputs,
-                                                   output_descs, state);
-                               })
+  int batch_size = absl::GetFlag(FLAGS_batch_size);
+  benchmark::RegisterBenchmark(
+      benchmark_name.c_str(),
+      [benchmark_name, batch_size, context, function, inputs,
+       output_descs](benchmark::State& state) -> void {
+        BenchmarkFunction(benchmark_name, batch_size, context, function, inputs,
+                          output_descs, state);
+      })
       // By default only the main thread is included in CPU time. Include all
       // the threads instead.
       ->MeasureProcessCPUTime()