Add support to benchmark in batch mode. (#5012)
If `--iree-hal-benchmark-dispatch-repeat-count` is set during translation, every
hal.command_buffer.dispatch op is emitted
`--iree-hal-benchmark-dispatch-repeat-count` times. In that case you have to pass
the same number as `--batch_size` to iree-benchmark-module, so that each
benchmark loop pass is counted as that many iterations.
This helps amortize fixed overhead when the execution time is small.
Fixes https://github.com/google/iree/issues/3995
diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD
index 6624e05..007ae40 100644
--- a/iree/compiler/Dialect/HAL/Transforms/BUILD
+++ b/iree/compiler/Dialect/HAL/Transforms/BUILD
@@ -21,6 +21,7 @@
cc_library(
name = "Transforms",
srcs = [
+ "BenchmarkBatchDispatches.cpp",
"CSEVariableLoads.cpp",
"ConvertToHAL.cpp",
"IdentifyConstantPools.cpp",
diff --git a/iree/compiler/Dialect/HAL/Transforms/BenchmarkBatchDispatches.cpp b/iree/compiler/Dialect/HAL/Transforms/BenchmarkBatchDispatches.cpp
new file mode 100644
index 0000000..383f7bd
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/BenchmarkBatchDispatches.cpp
@@ -0,0 +1,71 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace HAL {
+namespace {
+
+// A pass that repeats every hal.command_buffer.dispatch op `repeatCount` times.
+class BenchmarkBatchDispatchesPass
+ : public PassWrapper<BenchmarkBatchDispatchesPass, OperationPass<FuncOp>> {
+ public:
+ explicit BenchmarkBatchDispatchesPass(unsigned repeatCount)
+ : repeatCount_(repeatCount) {}
+
+ void getDependentDialects(DialectRegistry& registry) const override {
+ registry.insert<HALDialect, StandardOpsDialect>();
+ }
+
+ void runOnOperation() override {
+ FuncOp f = getOperation();
+ SmallVector<HAL::CommandBufferDispatchOp> ops;
+ f.walk([&](HAL::CommandBufferDispatchOp op) { ops.push_back(op); });
+
+ for (auto op : ops) {
+ OpBuilder builder(op);
+ for (unsigned i = 0; i < repeatCount_; ++i) {
+ builder.clone(*op.getOperation());
+ }
+ op.erase();
+ }
+ }
+
+ private:
+ unsigned repeatCount_;
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createBenchmarkBatchDispatchesPass(
+ unsigned repeatCount) {
+ return std::make_unique<BenchmarkBatchDispatchesPass>(repeatCount);
+}
+
+static PassRegistration<BenchmarkBatchDispatchesPass> pass(
+ "test-iree-hal-benchmark-batch-dispatches-2-times",
+ "Test pass used for benchmarking batch dispatches analysis",
+ [] { return std::make_unique<BenchmarkBatchDispatchesPass>(2); });
+
+} // namespace HAL
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 80b4863..2a0d76b 100644
--- a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@
HDRS
"Passes.h"
SRCS
+ "BenchmarkBatchDispatches.cpp"
"CSEVariableLoads.cpp"
"ConvertToHAL.cpp"
"IdentifyConstantPools.cpp"
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index ce10e6b..29f64a4 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -41,6 +41,12 @@
llvm::cl::init(true)};
};
+static llvm::cl::opt<unsigned> benchmarkDispatchRepeatCount{
+ "iree-hal-benchmark-dispatch-repeat-count",
+ llvm::cl::desc(
+ "The number of times to repeat each hal.command_buffer.dispatch op."),
+ llvm::cl::init(1)};
+
} // namespace
void buildHALTransformPassPipeline(OpPassManager &passManager,
@@ -110,6 +116,10 @@
// Inline hal.device.switch ops and memoize their queries such that we can
// better CSE/fold dispatch logic.
passManager.addNestedPass<FuncOp>(createInlineDeviceSwitchesPass());
+ if (benchmarkDispatchRepeatCount != 1) {
+ passManager.addNestedPass<FuncOp>(
+ createBenchmarkBatchDispatchesPass(benchmarkDispatchRepeatCount));
+ }
passManager.addPass(createLowerAffinePass());
passManager.addPass(createMemoizeDeviceQueriesPass());
passManager.addNestedPass<FuncOp>(createCanonicalizerPass());
diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h
index ef536c5..99c10f3 100644
--- a/iree/compiler/Dialect/HAL/Transforms/Passes.h
+++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h
@@ -130,6 +130,11 @@
// TODO(#1124): replace with memory side effects once supported upstream.
std::unique_ptr<OperationPass<FuncOp>> createCSEVariableLoadsPass();
+// Repeats each hal.command_buffer.dispatch op `repeatCount` times. The HAL
+// pipeline passes `--iree-hal-benchmark-dispatch-repeat-count` (default: 1).
+std::unique_ptr<OperationPass<FuncOp>> createBenchmarkBatchDispatchesPass(
+ unsigned repeatCount);
+
//===----------------------------------------------------------------------===//
// Register all Passes
//===----------------------------------------------------------------------===//
@@ -138,6 +143,7 @@
registerHALTransformPassPipeline();
auto executableOptions = getTargetOptionsFromFlags();
createConvertToHALPass();
+ createBenchmarkBatchDispatchesPass(/*repeatCount=*/1);
createInlineDeviceSwitchesPass();
createMemoizeDeviceQueriesPass();
createMaterializeInterfacesPass(executableOptions);
diff --git a/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir
new file mode 100644
index 0000000..45b479e
--- /dev/null
+++ b/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-opt -split-input-file -test-iree-hal-benchmark-batch-dispatches-2-times %s | IreeFileCheck %s
+
+hal.variable @_executable_0 : !hal.executable
+func @multiple_reads_no_writes() {
+ %0 = hal.variable.load @_executable_0 : !hal.executable
+ %1 = hal.variable.load @_executable_0 : !hal.executable
+ %2 = hal.variable.load @_executable_0 : !hal.executable
+
+ %c1 = constant 1 : index
+ %dev = hal.ex.shared_device : !hal.device
+ %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer
+ hal.command_buffer.begin %cmd
+ hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+ hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+ hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+ hal.command_buffer.end %cmd
+ return
+}
+// CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
+// CHECK: hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1]
diff --git a/iree/tools/iree-benchmark-module-main.cc b/iree/tools/iree-benchmark-module-main.cc
index b567c1d..058dc2b 100644
--- a/iree/tools/iree-benchmark-module-main.cc
+++ b/iree/tools/iree-benchmark-module-main.cc
@@ -31,6 +31,13 @@
"File containing the module to load that contains the entry "
"function. Defaults to stdin.");
+// TODO(hanchung): Extract the batch size using
+// iree_vm_function_reflection_attr.
+ABSL_FLAG(
+ int, batch_size, 1,
+ "The number of batch size, which is expected to match "
+ "iree-hal-benchmark-dispatch-repeat-count when translating the module");
+
ABSL_FLAG(std::string, entry_function, "",
"Name of a function contained in the module specified by module_file "
"to run. If this is not set, all the exported functions will be "
@@ -58,15 +65,16 @@
namespace {
static void BenchmarkFunction(
- const std::string& benchmark_name, iree_vm_context_t* context,
- iree_vm_function_t function, iree_vm_list_t* inputs,
+ const std::string& benchmark_name, int batch_size,
+ iree_vm_context_t* context, iree_vm_function_t function,
+ iree_vm_list_t* inputs,
const std::vector<RawSignatureParser::Description>& output_descs,
benchmark::State& state) {
IREE_TRACE_SCOPE_DYNAMIC(benchmark_name.c_str());
IREE_TRACE_FRAME_MARK();
// Benchmarking loop.
- for (auto _ : state) {
+ while (state.KeepRunningBatch(batch_size)) {
IREE_TRACE_SCOPE0("BenchmarkIteration");
IREE_TRACE_FRAME_MARK_NAMED("Iteration");
vm::ref<iree_vm_list_t> outputs;
@@ -83,13 +91,14 @@
iree_vm_function_t function, iree_vm_list_t* inputs,
const std::vector<RawSignatureParser::Description>& output_descs) {
auto benchmark_name = "BM_" + function_name;
- benchmark::RegisterBenchmark(benchmark_name.c_str(),
- [benchmark_name, context, function, inputs,
- output_descs](benchmark::State& state) -> void {
- BenchmarkFunction(benchmark_name, context,
- function, inputs,
- output_descs, state);
- })
+ int batch_size = absl::GetFlag(FLAGS_batch_size);
+ benchmark::RegisterBenchmark(
+ benchmark_name.c_str(),
+ [benchmark_name, batch_size, context, function, inputs,
+ output_descs](benchmark::State& state) -> void {
+ BenchmarkFunction(benchmark_name, batch_size, context, function, inputs,
+ output_descs, state);
+ })
// By default only the main thread is included in CPU time. Include all
// the threads instead.
->MeasureProcessCPUTime()