Fix device_allocator option in new benchmark suite (#12223)

Co-authored-by: Geoffrey Martin-Noble <gcmn@google.com>
diff --git a/build_tools/benchmarks/common/benchmark_definition.py b/build_tools/benchmarks/common/benchmark_definition.py
index f38ad5b..b3c91f0 100644
--- a/build_tools/benchmarks/common/benchmark_definition.py
+++ b/build_tools/benchmarks/common/benchmark_definition.py
@@ -138,7 +138,6 @@
     repetitions = 10
 
   cmd = [
-      "--device_allocator=caching",
       "--time_unit=ns",
       "--benchmark_format=json",
       "--benchmark_out_format=json",
diff --git a/build_tools/python/benchmark_suites/iree/module_execution_configs.py b/build_tools/python/benchmark_suites/iree/module_execution_configs.py
index d1b5dbb..fe81e97 100644
--- a/build_tools/python/benchmark_suites/iree/module_execution_configs.py
+++ b/build_tools/python/benchmark_suites/iree/module_execution_configs.py
@@ -5,35 +5,54 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 """Defines ModuleExecutionConfig for benchmarks."""
 
+from typing import List, Optional, Sequence
+
 from e2e_test_framework.definitions import iree_definitions
 from e2e_test_framework import unique_ids
 
-ELF_LOCAL_SYNC_CONFIG = iree_definitions.ModuleExecutionConfig(
+
+def _with_caching_allocator(
+    id: str,
+    tags: List[str],
+    loader: iree_definitions.RuntimeLoader,
+    driver: iree_definitions.RuntimeDriver,
+    extra_flags: Optional[Sequence[str]] = None
+) -> iree_definitions.ModuleExecutionConfig:
+  extra_flags = [] if extra_flags is None else list(extra_flags)
+  return iree_definitions.ModuleExecutionConfig(
+      id=id,
+      tags=tags,
+      loader=loader,
+      driver=driver,
+      extra_flags=["--device_allocator=caching"] + extra_flags)
+
+
+ELF_LOCAL_SYNC_CONFIG = _with_caching_allocator(
     id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_LOCAL_SYNC,
     tags=["full-inference", "default-flags"],
     loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF,
     driver=iree_definitions.RuntimeDriver.LOCAL_SYNC)
 
-CUDA_CONFIG = iree_definitions.ModuleExecutionConfig(
+CUDA_CONFIG = _with_caching_allocator(
     id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_CUDA,
     tags=["full-inference", "default-flags"],
     loader=iree_definitions.RuntimeLoader.NONE,
     driver=iree_definitions.RuntimeDriver.CUDA)
 
-VULKAN_CONFIG = iree_definitions.ModuleExecutionConfig(
+VULKAN_CONFIG = _with_caching_allocator(
     id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_VULKAN,
     tags=["full-inference", "default-flags"],
     loader=iree_definitions.RuntimeLoader.NONE,
     driver=iree_definitions.RuntimeDriver.VULKAN)
 
-VULKAN_BATCH_SIZE_16_CONFIG = iree_definitions.ModuleExecutionConfig(
+VULKAN_BATCH_SIZE_16_CONFIG = _with_caching_allocator(
     id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_VULKAN_BATCH_SIZE_16,
     tags=["full-inference", "experimental-flags"],
     loader=iree_definitions.RuntimeLoader.NONE,
     driver=iree_definitions.RuntimeDriver.VULKAN,
     extra_flags=["--batch_size=16"])
 
-VULKAN_BATCH_SIZE_32_CONFIG = iree_definitions.ModuleExecutionConfig(
+VULKAN_BATCH_SIZE_32_CONFIG = _with_caching_allocator(
     id=unique_ids.IREE_MODULE_EXECUTION_CONFIG_VULKAN_BATCH_SIZE_32,
     tags=["full-inference", "experimental-flags"],
     loader=iree_definitions.RuntimeLoader.NONE,
@@ -43,7 +62,7 @@
 
 def get_elf_local_task_config(thread_num: int):
   config_id = f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_LOCAL_TASK_BASE}-{thread_num}"
-  return iree_definitions.ModuleExecutionConfig(
+  return _with_caching_allocator(
       id=config_id,
       tags=[f"{thread_num}-thread", "full-inference", "default-flags"],
       loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF,
@@ -53,7 +72,7 @@
 
 def get_vmvx_local_task_config(thread_num: int):
   config_id = f"{unique_ids.IREE_MODULE_EXECUTION_CONFIG_VMVX_LOCAL_TASK_BASE}-{thread_num}"
-  return iree_definitions.ModuleExecutionConfig(
+  return _with_caching_allocator(
       id=config_id,
       tags=[f"{thread_num}-thread", "full-inference", "default-flags"],
       loader=iree_definitions.RuntimeLoader.VMVX_MODULE,
diff --git a/build_tools/python/e2e_model_tests/cmake_generator.py b/build_tools/python/e2e_model_tests/cmake_generator.py
index baf5e6c..1345c92 100644
--- a/build_tools/python/e2e_model_tests/cmake_generator.py
+++ b/build_tools/python/e2e_model_tests/cmake_generator.py
@@ -26,6 +26,11 @@
     runner_args = run_module_utils.build_run_flags_for_model(
         model=model,
         model_input_data=test_config.input_data) + test_config.extra_test_flags
+    # TODO(#11136): Currently the DRIVER is a separate field in the CMake rule
+    # (and affects the test labels). Rules should be generated in another way
+    # to avoid that. For now, generate the flags without the driver.
+    runner_args += run_module_utils.build_run_flags_for_execution_config(
+        test_config.execution_config, with_driver=False)
     cmake_rule = cmake_builder.rules.build_iree_benchmark_suite_module_test(
         target_name=test_config.name,
         model=f"{model.id}_{model.name}",
diff --git a/build_tools/python/e2e_model_tests/run_module_utils.py b/build_tools/python/e2e_model_tests/run_module_utils.py
index 7e3827e..bd88c08 100644
--- a/build_tools/python/e2e_model_tests/run_module_utils.py
+++ b/build_tools/python/e2e_model_tests/run_module_utils.py
@@ -26,15 +26,26 @@
 
 def build_run_flags_for_execution_config(
     module_execution_config: ModuleExecutionConfig,
-    gpu_id: str = "0") -> List[str]:
-  """Returns the IREE run module flags of the execution config."""
+    gpu_id: str = "0",
+    with_driver: bool = True) -> List[str]:
+  """Returns the IREE run module flags of the execution config.
 
-  run_flags = list(module_execution_config.extra_flags)
-  driver = module_execution_config.driver
-  if driver == RuntimeDriver.CUDA:
-    run_flags.append(f"--device=cuda://{gpu_id}")
-  else:
-    run_flags.append(f"--device={driver.value}")
+  Args:
+    module_execution_config: execution config.
+    gpu_id: target GPU id, used when running on GPUs.
+    with_driver: populate the driver flags if true. False can be used for
+      generating flags for some CMake rules with a separate DRIVER arg.
+  Returns:
+    List of flags.
+  """
+
+  run_flags = module_execution_config.extra_flags.copy()
+  if with_driver:
+    driver = module_execution_config.driver
+    if driver == RuntimeDriver.CUDA:
+      run_flags.append(f"--device=cuda://{gpu_id}")
+    else:
+      run_flags.append(f"--device={driver.value}")
   return run_flags
 
 
diff --git a/build_tools/python/e2e_model_tests/run_module_utils_test.py b/build_tools/python/e2e_model_tests/run_module_utils_test.py
index dccb98f..5a255da 100644
--- a/build_tools/python/e2e_model_tests/run_module_utils_test.py
+++ b/build_tools/python/e2e_model_tests/run_module_utils_test.py
@@ -55,6 +55,19 @@
 
     self.assertEqual(flags, ["--device=cuda://3"])
 
+  def test_build_run_flags_for_execution_config_without_driver(self):
+    execution_config = iree_definitions.ModuleExecutionConfig(
+        id="123",
+        tags=["test"],
+        loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF,
+        driver=iree_definitions.RuntimeDriver.LOCAL_TASK,
+        extra_flags=["--task=10"])
+
+    flags = run_module_utils.build_run_flags_for_execution_config(
+        execution_config, with_driver=False)
+
+    self.assertEqual(flags, ["--task=10"])
+
   def test_build_linux_wrapper_cmds_for_device_spec(self):
     device_spec = common_definitions.DeviceSpec(
         id="abc",
diff --git a/tests/e2e/models/generated_e2e_model_tests.cmake b/tests/e2e/models/generated_e2e_model_tests.cmake
index ea0a24f..f33f379 100644
--- a/tests/e2e/models/generated_e2e_model_tests.cmake
+++ b/tests/e2e/models/generated_e2e_model_tests.cmake
@@ -32,6 +32,7 @@
   RUNNER_ARGS
     "--function=main"
     "--input=1x224x224x3xf32=0"
+    "--device_allocator=caching"
   UNSUPPORTED_PLATFORMS
     "riscv32-Linux"
     "android-arm64-v8a"
@@ -49,6 +50,7 @@
   RUNNER_ARGS
     "--function=main"
     "--input=1x224x224x3xui8=0"
+    "--device_allocator=caching"
   UNSUPPORTED_PLATFORMS
     "android-arm64-v8a"
 )
@@ -66,6 +68,7 @@
     "--function=main"
     "--input=1x257x257x3xf32=0"
     "--expected_f32_threshold=0.001"
+    "--device_allocator=caching"
   UNSUPPORTED_PLATFORMS
     "riscv32-Linux"
 )
@@ -82,6 +85,7 @@
   RUNNER_ARGS
     "--function=main"
     "--input=1x96x96x1xi8=0"
+    "--device_allocator=caching"
   UNSUPPORTED_PLATFORMS
     "android-arm64-v8a"
 )