Rework iree-run-mlir to operate against the IREE compiler C API. (#12715)

This removes some unused functionality while sprinkling in the seeds of
something better. Future changes can go deeper on inferring compiler
configuration from available devices and such (will need some reworking
of run_module.c - I erred on the side of simplicity/reuse today). This
isn't meant to be an example of the best way to build an online compiler
but instead to serve as an in-tree test of something that approximates one.
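
For reference, a rough sketch of the new invocation style (taken from the
updated lit tests and tool comments; the exact flags depend on the target
backend and device in use), where compiler flags are forwarded via
--Xcompiler and runtime flags are parsed directly:

  $ iree-run-mlir \
        --Xcompiler,iree-hal-target-backends=llvm-cpu \
        --device=local-task \
        --input=2xf32=2,3 \
        file.mlir
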
diff --git a/compiler/src/iree/compiler/API/Internal/Embed.cpp b/compiler/src/iree/compiler/API/Internal/Embed.cpp
index d95c655..0b24526 100644
--- a/compiler/src/iree/compiler/API/Internal/Embed.cpp
+++ b/compiler/src/iree/compiler/API/Internal/Embed.cpp
@@ -578,16 +578,17 @@
               cSeverity = IREE_COMPILER_DIAGNOSTIC_SEVERITY_NOTE;
               break;
             case DiagnosticSeverity::Warning:
-              cSeverity = IREE_COMPILER_DIAGNOSTIC_SEVERITY_NOTE;
+              cSeverity = IREE_COMPILER_DIAGNOSTIC_SEVERITY_WARNING;
               break;
             case DiagnosticSeverity::Error:
-              cSeverity = IREE_COMPILER_DIAGNOSTIC_SEVERITY_NOTE;
+              cSeverity = IREE_COMPILER_DIAGNOSTIC_SEVERITY_ERROR;
               break;
             case DiagnosticSeverity::Remark:
-              cSeverity = IREE_COMPILER_DIAGNOSTIC_SEVERITY_NOTE;
+              cSeverity = IREE_COMPILER_DIAGNOSTIC_SEVERITY_REMARK;
               break;
             default:
               cSeverity = IREE_COMPILER_DIAGNOSTIC_SEVERITY_ERROR;
+              break;
           }
           diagnosticCallback(cSeverity, message.data(), message.size(),
                              diagnosticCallbackUserData);
diff --git a/runtime/src/iree/base/string_view.c b/runtime/src/iree/base/string_view.c
index 92c9be1..d198329 100644
--- a/runtime/src/iree/base/string_view.c
+++ b/runtime/src/iree/base/string_view.c
@@ -210,14 +210,14 @@
                                                 char split_char,
                                                 iree_string_view_t* out_lhs,
                                                 iree_string_view_t* out_rhs) {
-  *out_lhs = iree_string_view_empty();
-  *out_rhs = iree_string_view_empty();
+  if (out_lhs) *out_lhs = iree_string_view_empty();
+  if (out_rhs) *out_rhs = iree_string_view_empty();
   if (!value.data || !value.size) {
     return -1;
   }
   const void* first_ptr = memchr(value.data, split_char, value.size);
   if (!first_ptr) {
-    *out_lhs = value;
+    if (out_lhs) *out_lhs = value;
     return -1;
   }
   intptr_t offset = (intptr_t)((const char*)(first_ptr)-value.data);
diff --git a/runtime/src/iree/base/string_view_test.cc b/runtime/src/iree/base/string_view_test.cc
index d41fa3c..a2dbf7c 100644
--- a/runtime/src/iree/base/string_view_test.cc
+++ b/runtime/src/iree/base/string_view_test.cc
@@ -357,8 +357,8 @@
   auto split =
       [](const char* value,
          char split_char) -> std::tuple<intptr_t, std::string, std::string> {
-    iree_string_view_t lhs;
-    iree_string_view_t rhs;
+    iree_string_view_t lhs = iree_string_view_empty();
+    iree_string_view_t rhs = iree_string_view_empty();
     intptr_t index = iree_string_view_split(iree_make_cstring_view(value),
                                             split_char, &lhs, &rhs);
     return std::make_tuple(index, ToString(lhs), ToString(rhs));
@@ -374,6 +374,61 @@
   EXPECT_EQ(split("axbxc", 'x'), std::make_tuple(1, "a", "bxc"));
 }
 
+// Tests that partial returns from iree_string_view_split (only LHS or RHS) work
+// as expected even when no storage is provided.
+TEST(StringViewTest, SplitLHSOnly) {
+  auto split_lhs = [](const char* value,
+                      char split_char) -> std::tuple<intptr_t, std::string> {
+    iree_string_view_t lhs = iree_string_view_empty();
+    intptr_t index = iree_string_view_split(iree_make_cstring_view(value),
+                                            split_char, &lhs, nullptr);
+    return std::make_tuple(index, ToString(lhs));
+  };
+  EXPECT_EQ(split_lhs("", 'x'), std::make_tuple(-1, ""));
+  EXPECT_EQ(split_lhs(" ", 'x'), std::make_tuple(-1, " "));
+  EXPECT_EQ(split_lhs("x", 'x'), std::make_tuple(0, ""));
+  EXPECT_EQ(split_lhs(" x ", 'x'), std::make_tuple(1, " "));
+  EXPECT_EQ(split_lhs("axb", 'x'), std::make_tuple(1, "a"));
+  EXPECT_EQ(split_lhs("axxxb", 'x'), std::make_tuple(1, "a"));
+  EXPECT_EQ(split_lhs("ax", 'x'), std::make_tuple(1, "a"));
+  EXPECT_EQ(split_lhs("xb", 'x'), std::make_tuple(0, ""));
+  EXPECT_EQ(split_lhs("axbxc", 'x'), std::make_tuple(1, "a"));
+}
+TEST(StringViewTest, SplitRHSOnly) {
+  auto split_rhs = [](const char* value,
+                      char split_char) -> std::tuple<intptr_t, std::string> {
+    iree_string_view_t rhs = iree_string_view_empty();
+    intptr_t index = iree_string_view_split(iree_make_cstring_view(value),
+                                            split_char, nullptr, &rhs);
+    return std::make_tuple(index, ToString(rhs));
+  };
+  EXPECT_EQ(split_rhs("", 'x'), std::make_tuple(-1, ""));
+  EXPECT_EQ(split_rhs(" ", 'x'), std::make_tuple(-1, ""));
+  EXPECT_EQ(split_rhs("x", 'x'), std::make_tuple(0, ""));
+  EXPECT_EQ(split_rhs(" x ", 'x'), std::make_tuple(1, " "));
+  EXPECT_EQ(split_rhs("axb", 'x'), std::make_tuple(1, "b"));
+  EXPECT_EQ(split_rhs("axxxb", 'x'), std::make_tuple(1, "xxb"));
+  EXPECT_EQ(split_rhs("ax", 'x'), std::make_tuple(1, ""));
+  EXPECT_EQ(split_rhs("xb", 'x'), std::make_tuple(0, "b"));
+  EXPECT_EQ(split_rhs("axbxc", 'x'), std::make_tuple(1, "bxc"));
+}
+TEST(StringViewTest, SplitReturnOnly) {
+  // This is effectively a find but with extra steps.
+  auto split_return = [](const char* value, char split_char) -> intptr_t {
+    return iree_string_view_split(iree_make_cstring_view(value), split_char,
+                                  nullptr, nullptr);
+  };
+  EXPECT_EQ(split_return("", 'x'), -1);
+  EXPECT_EQ(split_return(" ", 'x'), -1);
+  EXPECT_EQ(split_return("x", 'x'), 0);
+  EXPECT_EQ(split_return(" x ", 'x'), 1);
+  EXPECT_EQ(split_return("axb", 'x'), 1);
+  EXPECT_EQ(split_return("axxxb", 'x'), 1);
+  EXPECT_EQ(split_return("ax", 'x'), 1);
+  EXPECT_EQ(split_return("xb", 'x'), 0);
+  EXPECT_EQ(split_return("axbxc", 'x'), 1);
+}
+
 TEST(StringViewTest, ReplaceChar) {
   auto replace_char = [](const char* value, char old_char, char new_char) {
     std::string value_clone(value);
diff --git a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
index dfd625f..ad5160d 100644
--- a/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
+++ b/runtime/src/iree/hal/drivers/cuda/registration/driver_module.c
@@ -21,13 +21,14 @@
     "Use CUDA streams for executing command buffers (instead of graphs).");
 
 IREE_FLAG(bool, cuda_allow_inline_execution, false,
-          "Allow command buffers to execute inline against CUDA streams when "
+          "Allow command buffers to execute inline against CUDA streams when\n"
           "possible.");
 
-IREE_FLAG(bool, cuda_tracing, true,
-          "Enables tracing of stream events when Tracy instrumentation is "
-          "enabled. Severely impacts benchmark timings and should only be used "
-          "when analyzing dispatch timings.");
+IREE_FLAG(
+    bool, cuda_tracing, true,
+    "Enables tracing of stream events when Tracy instrumentation is enabled.\n"
+    "Severely impacts benchmark timings and should only be used when\n"
+    "analyzing dispatch timings.");
 
 IREE_FLAG(int32_t, cuda_default_index, 0, "Index of the default CUDA device.");
 
diff --git a/runtime/src/iree/hal/drivers/vulkan/registration/driver_module.cc b/runtime/src/iree/hal/drivers/vulkan/registration/driver_module.cc
index 5427fac..74e83de 100644
--- a/runtime/src/iree/hal/drivers/vulkan/registration/driver_module.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/registration/driver_module.cc
@@ -39,7 +39,7 @@
     "Use a dedicated queue with VK_QUEUE_COMPUTE_BIT for dispatch workloads.");
 IREE_FLAG(
     int64_t, vulkan_large_heap_block_size, 0,
-    "Preferred allocator block size for large allocations in bytes. Sets the "
+    "Preferred allocator block size for large allocations in bytes. Sets the\n"
     "minimum bound on memory consumption.");
 
 static iree_status_t iree_hal_vulkan_create_driver_with_flags(
diff --git a/runtime/src/iree/tooling/BUILD.bazel b/runtime/src/iree/tooling/BUILD.bazel
index 1254217..b8fb852 100644
--- a/runtime/src/iree/tooling/BUILD.bazel
+++ b/runtime/src/iree/tooling/BUILD.bazel
@@ -150,6 +150,26 @@
     ],
 )
 
+iree_runtime_cc_library(
+    name = "run_module",
+    srcs = ["run_module.c"],
+    hdrs = ["run_module.h"],
+    deps = [
+        ":comparison",
+        ":context_util",
+        ":device_util",
+        ":instrument_util",
+        ":vm_util",
+        "//runtime/src/iree/base",
+        "//runtime/src/iree/base:tracing",
+        "//runtime/src/iree/base/internal:flags",
+        "//runtime/src/iree/hal",
+        "//runtime/src/iree/modules/hal:types",
+        "//runtime/src/iree/vm",
+        "//runtime/src/iree/vm/bytecode:module",
+    ],
+)
+
 # TODO(benvanik): fold these into iree/runtime and use that instead.
 iree_runtime_cc_library(
     name = "vm_util",
diff --git a/runtime/src/iree/tooling/CMakeLists.txt b/runtime/src/iree/tooling/CMakeLists.txt
index 130f6dd..26052a1 100644
--- a/runtime/src/iree/tooling/CMakeLists.txt
+++ b/runtime/src/iree/tooling/CMakeLists.txt
@@ -172,6 +172,29 @@
 
 iree_cc_library(
   NAME
+    run_module
+  HDRS
+    "run_module.h"
+  SRCS
+    "run_module.c"
+  DEPS
+    ::comparison
+    ::context_util
+    ::device_util
+    ::instrument_util
+    ::vm_util
+    iree::base
+    iree::base::internal::flags
+    iree::base::tracing
+    iree::hal
+    iree::modules::hal::types
+    iree::vm
+    iree::vm::bytecode::module
+  PUBLIC
+)
+
+iree_cc_library(
+  NAME
     vm_util
   HDRS
     "vm_util.h"
diff --git a/runtime/src/iree/tooling/comparison.h b/runtime/src/iree/tooling/comparison.h
index c7e5acf..be9c8d5 100644
--- a/runtime/src/iree/tooling/comparison.h
+++ b/runtime/src/iree/tooling/comparison.h
@@ -7,9 +7,14 @@
 #ifndef IREE_TOOLING_COMPARISON_H_
 #define IREE_TOOLING_COMPARISON_H_
 
+#include <stdio.h>
+
 #include "iree/base/api.h"
 #include "iree/vm/api.h"
-#include "stdio.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
 
 // Compares expected vs actual results and appends to |builder|.
 // Returns true if all values match and false otherwise.
@@ -27,4 +32,8 @@
                                         iree_allocator_t host_allocator,
                                         FILE* file);
 
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
 #endif  // IREE_TOOLING_COMPARISON_H_
diff --git a/runtime/src/iree/tooling/device_util.c b/runtime/src/iree/tooling/device_util.c
index e16fe03..fb2e37a 100644
--- a/runtime/src/iree/tooling/device_util.c
+++ b/runtime/src/iree/tooling/device_util.c
@@ -374,14 +374,15 @@
 
 IREE_FLAG(
     string, device_profiling_mode, "",
-    "HAL device profiling mode (one of ['queue', 'dispatch', 'executable']) "
-    "or empty to disable profiling. HAL implementations may require additional "
-    "flags in order to configure profiling support on "
-    "their devices.");
+    "HAL device profiling mode (one of ['queue', 'dispatch', 'executable'])\n"
+    "or empty to disable profiling. HAL implementations may require\n"
+    "additional flags in order to configure profiling support on their\n"
+    "devices.");
 IREE_FLAG(
     string, device_profiling_file, "",
-    "Optional file path/prefix for profiling file output. Some implementations "
-    "may require a file name in order to capture profiling information.");
+    "Optional file path/prefix for profiling file output. Some\n"
+    "implementations may require a file name in order to capture profiling\n"
+    "information.");
 
 iree_status_t iree_hal_begin_profiling_from_flags(iree_hal_device_t* device) {
   if (!device) return iree_ok_status();
diff --git a/runtime/src/iree/tooling/run_module.c b/runtime/src/iree/tooling/run_module.c
new file mode 100644
index 0000000..0d416ab
--- /dev/null
+++ b/runtime/src/iree/tooling/run_module.c
@@ -0,0 +1,365 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/tooling/run_module.h"
+
+#include "iree/base/api.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+#include "iree/modules/hal/types.h"
+#include "iree/tooling/comparison.h"
+#include "iree/tooling/context_util.h"
+#include "iree/tooling/device_util.h"
+#include "iree/tooling/instrument_util.h"
+#include "iree/tooling/vm_util.h"
+#include "iree/vm/api.h"
+#include "iree/vm/bytecode/module.h"
+
+IREE_FLAG(string, function, "",
+          "Name of a function contained in the module specified by --module= "
+          "to run.");
+
+IREE_FLAG_LIST(
+    string, input,
+    "An input (a) value or (b) buffer of the format:\n"
+    "  (a) scalar value\n"
+    "     value\n"
+    "     e.g.: --input=\"3.14\"\n"
+    "  (b) buffer:\n"
+    "     [shape]xtype=[value]\n"
+    "     e.g.: --input=\"2x2xi32=1 2 3 4\"\n"
+    "Optionally, brackets may be used to separate the element values:\n"
+    "  2x2xi32=[[1 2][3 4]]\n"
+    "Raw binary files can be read to provide buffer contents:\n"
+    "  2x2xi32=@some/file.bin\n"
+    "\n"
+    "Numpy npy files from numpy.save can be read to provide 1+ values:\n"
+    "  @some.npy\n"
+    "\n"
+    "Each occurrence of the flag indicates an input in the order they were\n"
+    "specified on the command line.");
+
+IREE_FLAG_LIST(
+    string, output,
+    "Specifies how to handle an output from the invocation:\n"
+    "  `` (empty): ignore output\n"
+    "     e.g.: --output=\n"
+    "  `-`: print textual form to stdout\n"
+    "     e.g.: --output=-\n"
+    "  `@file.npy`: create/overwrite a numpy npy file and write buffer view\n"
+    "     e.g.: --output=@file.npy\n"
+    "  `+file.npy`: create/append a numpy npy file and write buffer view\n"
+    "     e.g.: --output=+file.npy\n"
+    "\n"
+    "Numpy npy files can be read in Python using numpy.load, for example an\n"
+    "invocation producing two outputs can be concatenated as:\n"
+    "    --output=@file.npy --output=+file.npy\n"
+    "And then loaded in Python by reading from the same file:\n"
+    "  with open('file.npy', 'rb') as f:\n"
+    "    print(numpy.load(f))\n"
+    "    print(numpy.load(f))\n"
+    "\n"
+    "Each occurrence of the flag indicates an output in the order they were\n"
+    "specified on the command line.");
+
+IREE_FLAG_LIST(
+    string, expected_output,
+    "An expected function output following the same format as `--input=`.\n"
+    "When present the results of the invocation will be compared against\n"
+    "these values and the tool will return non-zero if any differ. If the\n"
+    "value of a particular output is not of interest provide `(ignored)`.");
+
+IREE_FLAG(
+    int32_t, output_max_element_count, 1024,
+    "Prints up to the maximum number of elements of output tensors and elides\n"
+    "the remainder.");
+
+IREE_FLAG(bool, print_statistics, false,
+          "Prints runtime statistics to stderr on exit.");
+
+static iree_status_t iree_tooling_process_outputs(
+    iree_vm_list_t* outputs, iree_allocator_t host_allocator,
+    int* out_exit_code);
+
+static iree_status_t iree_tooling_create_run_context(
+    iree_vm_instance_t* instance, iree_string_view_t default_device_uri,
+    iree_const_byte_span_t module_contents, iree_allocator_t host_allocator,
+    iree_vm_context_t** out_context, iree_vm_function_t* out_function,
+    iree_hal_device_t** out_device,
+    iree_hal_allocator_t** out_device_allocator) {
+  // Load all modules specified by --module= flags.
+  iree_tooling_module_list_t module_list;
+  iree_tooling_module_list_initialize(&module_list);
+  IREE_RETURN_IF_ERROR(iree_tooling_load_modules_from_flags(
+                           instance, host_allocator, &module_list),
+                       "loading modules and dependencies");
+
+  // Load the optional bytecode module from the provided flatbuffer data.
+  // Note that we do this after all other --module= flags are processed so that
+  // we ensure any dependent types are registered with the instance.
+  iree_status_t status = iree_ok_status();
+  if (!iree_const_byte_span_is_empty(module_contents)) {
+    iree_vm_module_t* module = NULL;
+    status = iree_status_annotate_f(
+        iree_vm_bytecode_module_create(instance, module_contents,
+                                       iree_allocator_null(), host_allocator,
+                                       &module),
+        "loading custom bytecode module from memory");
+    if (iree_status_is_ok(status)) {
+      status = iree_tooling_module_list_push_back(&module_list, module);
+    }
+  }
+  if (!iree_status_is_ok(status)) {
+    iree_tooling_module_list_reset(&module_list);
+    return status;
+  }
+
+  // There's no concept of a "main module" in the VM but here we use it to
+  // allow for a shorthand --function= flag that doesn't need the module name.
+  iree_vm_module_t* main_module = iree_tooling_module_list_back(&module_list);
+  if (!main_module) {
+    status = iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "no user module specified; use --module=file.vmfb to load from a "
+        "file or --module=- to load from stdin");
+  }
+
+  // Create the VM context with all of the modules. Dependent modules will be
+  // loaded (like the HAL) and special things like the HAL device and allocator
+  // are returned for convenience. Note that not all programs need the HAL.
+  iree_vm_context_t* context = NULL;
+  iree_hal_device_t* device = NULL;
+  iree_hal_allocator_t* device_allocator = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_status_annotate_f(
+        iree_tooling_create_context_from_flags(
+            instance, module_list.count, module_list.values, default_device_uri,
+            host_allocator, &context, &device, &device_allocator),
+        "creating VM context");
+  }
+
+  iree_tooling_module_list_reset(&module_list);
+  if (!iree_status_is_ok(status)) {
+    return status;
+  }
+
+  // Choose which function to run - either the one specified in the flag or the
+  // only exported non-internal function.
+  iree_vm_function_t function = {0};
+  if (strlen(FLAG_function) == 0) {
+    status = iree_tooling_find_single_exported_function(main_module, &function);
+  } else {
+    status = iree_status_annotate_f(
+        iree_vm_module_lookup_function_by_name(
+            main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+            iree_make_cstring_view(FLAG_function), &function),
+        "looking up function '%s'", FLAG_function);
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_context = context;
+    *out_function = function;
+    *out_device = device;
+    *out_device_allocator = device_allocator;
+  } else {
+    iree_vm_context_release(context);
+    iree_hal_allocator_release(device_allocator);
+    iree_hal_device_release(device);
+  }
+  return status;
+}
+
+static iree_status_t iree_tooling_run_function(
+    iree_vm_context_t* context, iree_vm_function_t function,
+    iree_hal_device_t* device, iree_hal_allocator_t* device_allocator,
+    iree_allocator_t host_allocator, int* out_exit_code) {
+  iree_string_view_t function_name = iree_vm_function_name(&function);
+  (void)function_name;
+
+  // Parse --input= values into device buffers.
+  iree_vm_list_t* inputs = NULL;
+  iree_status_t status = iree_status_annotate_f(
+      iree_tooling_parse_to_variant_list(
+          device_allocator, FLAG_input_list().values, FLAG_input_list().count,
+          host_allocator, &inputs),
+      "parsing function inputs");
+
+  // If the function is async add fences so we can invoke it synchronously.
+  iree_hal_fence_t* finish_fence = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_status_annotate_f(
+        iree_tooling_append_async_fence_inputs(
+            inputs, &function, device, /*wait_fence=*/NULL, &finish_fence),
+        "setting up async-external fence inputs");
+  }
+
+  // Empty output list to be populated by the invocation.
+  iree_vm_list_t* outputs = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_vm_list_create(iree_vm_make_undefined_type_def(), 16,
+                                 host_allocator, &outputs);
+  }
+
+  // TODO(benvanik): move behind a --verbose flag and add more logging.
+  if (iree_status_is_ok(status)) {
+    fprintf(stdout, "EXEC @%.*s\n", (int)function_name.size,
+            function_name.data);
+  }
+
+  // Begin profiling immediately prior to invocation.
+  if (iree_status_is_ok(status)) {
+    status = iree_status_annotate_f(iree_hal_begin_profiling_from_flags(device),
+                                    "beginning device profiling");
+  }
+
+  // Invoke the function with the provided inputs.
+  if (iree_status_is_ok(status)) {
+    status = iree_status_annotate_f(
+        iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
+                       /*policy=*/NULL, inputs, outputs, host_allocator),
+        "invoking function '%.*s'", (int)function_name.size,
+        function_name.data);
+  }
+  iree_vm_list_release(inputs);
+
+  // If the function is async we need to wait for it to complete.
+  if (iree_status_is_ok(status) && finish_fence) {
+    IREE_RETURN_IF_ERROR(
+        iree_hal_fence_wait(finish_fence, iree_infinite_timeout()),
+        "waiting on finish fence");
+  }
+
+  // End profiling after waiting for the invocation to finish.
+  if (iree_status_is_ok(status)) {
+    status = iree_status_annotate_f(iree_hal_end_profiling_from_flags(device),
+                                    "ending device profiling");
+  }
+
+  // Grab any instrumentation data present in the context and write it to disk.
+  if (iree_status_is_ok(status)) {
+    status = iree_status_annotate_f(
+        iree_tooling_process_instrument_data(context, host_allocator),
+        "processing instrument data");
+  }
+
+  // Handle either printing/writing the outputs or checking them against
+  // expected values (basic pass/fail testing).
+  if (iree_status_is_ok(status)) {
+    status = iree_status_annotate_f(
+        iree_tooling_process_outputs(outputs, host_allocator, out_exit_code),
+        "processing function outputs");
+  }
+  iree_vm_list_release(outputs);
+
+  return status;
+}
+
+static iree_status_t iree_tooling_process_outputs(
+    iree_vm_list_t* outputs, iree_allocator_t host_allocator,
+    int* out_exit_code) {
+  *out_exit_code = EXIT_SUCCESS;
+
+  // Basic output handling to route to the console or files.
+  if (FLAG_expected_output_list().count == 0) {
+    if (FLAG_output_list().count == 0) {
+      // Print all outputs.
+      return iree_status_annotate_f(
+          iree_tooling_variant_list_fprint(
+              IREE_SV("result"), outputs,
+              (iree_host_size_t)FLAG_output_max_element_count, stdout),
+          "printing results");
+    } else {
+      // Write (or ignore) all outputs.
+      return iree_status_annotate_f(
+          iree_tooling_output_variant_list(
+              outputs, FLAG_output_list().values, FLAG_output_list().count,
+              (iree_host_size_t)FLAG_output_max_element_count, stdout),
+          "outputting results");
+    }
+  }
+
+  // Compare against contents in host-local memory. This avoids polluting
+  // device memory statistics.
+  iree_hal_allocator_t* heap_allocator = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
+      IREE_SV("heap"), host_allocator, host_allocator, &heap_allocator));
+
+  // Parse expected list into host-local memory that we can easily access.
+  iree_vm_list_t* expected_list = NULL;
+  iree_status_t status = iree_status_annotate_f(
+      iree_tooling_parse_to_variant_list(
+          heap_allocator, FLAG_expected_output_list().values,
+          FLAG_expected_output_list().count, host_allocator, &expected_list),
+      "parsing expected function outputs");
+
+  // Compare expected vs actual lists and output diffs.
+  if (iree_status_is_ok(status)) {
+    bool did_match = iree_tooling_compare_variant_lists(expected_list, outputs,
+                                                        host_allocator, stdout);
+    if (did_match) {
+      fprintf(
+          stdout,
+          "[SUCCESS] all function outputs matched their expected values.\n");
+    }
+
+    // Exit code 0 if all results matched the expected values.
+    *out_exit_code = did_match ? EXIT_SUCCESS : EXIT_FAILURE;
+  }
+
+  iree_vm_list_release(expected_list);
+  iree_hal_allocator_release(heap_allocator);
+  return status;
+}
+
+iree_status_t iree_tooling_run_module_from_flags(
+    iree_vm_instance_t* instance, iree_allocator_t host_allocator,
+    int* out_exit_code) {
+  return iree_tooling_run_module_with_data(instance, iree_string_view_empty(),
+                                           iree_const_byte_span_empty(),
+                                           host_allocator, out_exit_code);
+}
+
+iree_status_t iree_tooling_run_module_with_data(
+    iree_vm_instance_t* instance, iree_string_view_t default_device_uri,
+    iree_const_byte_span_t module_contents, iree_allocator_t host_allocator,
+    int* out_exit_code) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Setup the VM context with all required modules and get the function to run.
+  // This also returns the HAL device and allocator (if any) for I/O handling.
+  iree_vm_context_t* context = NULL;
+  iree_vm_function_t function = {0};
+  iree_hal_device_t* device = NULL;
+  iree_hal_allocator_t* device_allocator = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_tooling_create_run_context(instance, default_device_uri,
+                                      module_contents, host_allocator, &context,
+                                      &function, &device, &device_allocator),
+      "creating run context");
+
+  // Parse inputs, run the function, and process outputs.
+  iree_status_t status =
+      iree_tooling_run_function(context, function, device, device_allocator,
+                                host_allocator, out_exit_code);
+
+  // Release the context and all retained resources (variables, constants, etc).
+  iree_vm_context_release(context);
+
+  // Print statistics after we've released the inputs/outputs and the context
+  // which may be holding on to resources like constants/variables.
+  if (device_allocator && FLAG_print_statistics) {
+    IREE_IGNORE_ERROR(
+        iree_hal_allocator_statistics_fprint(stderr, device_allocator));
+  }
+
+  iree_hal_allocator_release(device_allocator);
+  iree_hal_device_release(device);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
diff --git a/runtime/src/iree/tooling/run_module.h b/runtime/src/iree/tooling/run_module.h
new file mode 100644
index 0000000..c29f399
--- /dev/null
+++ b/runtime/src/iree/tooling/run_module.h
@@ -0,0 +1,48 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_TOOLING_RUN_MODULE_H_
+#define IREE_TOOLING_RUN_MODULE_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/vm/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Runs the module/function specified on the command line with inputs/outputs.
+// Returns the process result code in |out_exit_code| (0 for success).
+//
+// One or more --module= flags can be used to specify all required modules.
+// --function= is used to specify which function in the last module registered
+// is to be executed. One --input= flag per function input can be used to
+// provide function inputs from textual or file sources. One --output= flag per
+// function output can be used to write outputs to a file. Optionally
+// --expected_output= flags can be used to perform basic comparisons against
+// the actual function outputs. See --help for more information.
+iree_status_t iree_tooling_run_module_from_flags(
+    iree_vm_instance_t* instance, iree_allocator_t host_allocator,
+    int* out_exit_code);
+
+// Runs the module/function specified on the command line with the given
+// in-memory main module. Equivalent to iree_tooling_run_module_from_flags but
+// the provided |module_contents| are registered with the context prior to
+// execution.
+//
+// Optionally |default_device_uri| can be used to specify which device should
+// be used if no --device= flag is provided by the user.
+iree_status_t iree_tooling_run_module_with_data(
+    iree_vm_instance_t* instance, iree_string_view_t default_device_uri,
+    iree_const_byte_span_t module_contents, iree_allocator_t host_allocator,
+    int* out_exit_code);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_TOOLING_RUN_MODULE_H_
diff --git a/tests/e2e/linalg_transform/linalg_transform.mlir b/tests/e2e/linalg_transform/linalg_transform.mlir
index a7fd3d2..7ad4ae2 100644
--- a/tests/e2e/linalg_transform/linalg_transform.mlir
+++ b/tests/e2e/linalg_transform/linalg_transform.mlir
@@ -1,4 +1,4 @@
-// R-UN: iree-run-mlir --iree-hal-target-backends=llvm-cpu \
+// R-UN: iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu \
 /// Specify the dispatch region formation with the transform dialect.
 // R-UN:   --iree-flow-dispatch-use-transform-dialect=%p/transform_dialect_dispatch_spec.mlir \
 /// Specify the codegen strategy with the transform dialect.
diff --git a/tests/e2e/models/collatz.mlir b/tests/e2e/models/collatz.mlir
index f8e8787..cdf0ac4 100644
--- a/tests/e2e/models/collatz.mlir
+++ b/tests/e2e/models/collatz.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx %s | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vmvx %s | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
 
 // CHECK-LABEL: EXEC @collatz
 func.func @collatz() -> tensor<f32> {
diff --git a/tests/e2e/models/edge_detection.mlir b/tests/e2e/models/edge_detection.mlir
index 0e8411d..7c8d53e 100644
--- a/tests/e2e/models/edge_detection.mlir
+++ b/tests/e2e/models/edge_detection.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx %s --input=1x128x128x1xf32 | FileCheck %s
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --input=1x128x128x1xf32 | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s --input=1x128x128x1xf32 | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vmvx %s --input=1x128x128x1xf32 | FileCheck %s
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=llvm-cpu %s --input=1x128x128x1xf32 | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vulkan-spirv %s --input=1x128x128x1xf32 | FileCheck %s)
 
 // Image edge detection module generated by.
 // https://github.com/openxla/iree/blob/main/samples/colab/edge_detection.ipynb.
diff --git a/tests/e2e/models/fragment_000.mlir b/tests/e2e/models/fragment_000.mlir
index 37981e3..d8365b7 100644
--- a/tests/e2e/models/fragment_000.mlir
+++ b/tests/e2e/models/fragment_000.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx %s | FileCheck %s
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vmvx %s | FileCheck %s
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=llvm-cpu %s | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
 
 // CHECK-LABEL: EXEC @entry
 func.func @entry() -> tensor<5x5xf32> {
diff --git a/tests/e2e/models/fullyconnected.mlir b/tests/e2e/models/fullyconnected.mlir
index 015d36d..8a07f0c 100644
--- a/tests/e2e/models/fullyconnected.mlir
+++ b/tests/e2e/models/fullyconnected.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --input=1x5xf32=1,-2,-3,4,-5 --input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s --input=1x5xf32=1,-2,-3,4,-5 --input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=llvm-cpu %s --input=1x5xf32=1,-2,-3,4,-5 --input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vulkan-spirv %s --input=1x5xf32=1,-2,-3,4,-5 --input=1x5x3x1xf32=15,14,13,12,11,10,9,8,7,6,5,4,3,2,1 | FileCheck %s)
 
 // CHECK-LABEL: EXEC @main
 func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5x3x1xf32>) -> tensor<5x1x5xf32> {
diff --git a/tests/e2e/models/mnist_fake_weights.mlir b/tests/e2e/models/mnist_fake_weights.mlir
index 7892e5a..8f91b32 100644
--- a/tests/e2e/models/mnist_fake_weights.mlir
+++ b/tests/e2e/models/mnist_fake_weights.mlir
@@ -1,8 +1,8 @@
 // MNIST model with placeholder weights, for testing.
 
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx %s --input=1x28x28x1xf32 | FileCheck %s
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --input=1x28x28x1xf32 | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s --input=1x28x28x1xf32 | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vmvx %s --input=1x28x28x1xf32 | FileCheck %s
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=llvm-cpu %s --input=1x28x28x1xf32 | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vulkan-spirv %s --input=1x28x28x1xf32 | FileCheck %s)
 
 module {
   util.global private @"__iree_flow___sm_node17__model.layer-1.kernel" {noinline} = dense<1.000000e+00> : tensor<784x128xf32>
diff --git a/tests/e2e/models/resnet50_fake_weights.mlir b/tests/e2e/models/resnet50_fake_weights.mlir
index 2be011a..dc7aecd 100644
--- a/tests/e2e/models/resnet50_fake_weights.mlir
+++ b/tests/e2e/models/resnet50_fake_weights.mlir
@@ -1,8 +1,8 @@
 // ResNet50 model with placeholder weights, for testing.
 // Generated by resnet.ipynb with some manual and automated cleanup for testing.
 
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --input=1x224x224x3xf32 | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s --input=1x224x224x3xf32 | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=llvm-cpu %s --input=1x224x224x3xf32 | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vulkan-spirv %s --input=1x224x224x3xf32 | FileCheck %s)
 
 module {
   util.global private @"__iree_flow___sm_node188__m.layer-2.kernel" {noinline} = dense<1.000000e+00> : tensor<7x7x3x64xf32>
diff --git a/tests/e2e/models/unidirectional_lstm.mlir b/tests/e2e/models/unidirectional_lstm.mlir
index 8ed02e4..ab28114 100644
--- a/tests/e2e/models/unidirectional_lstm.mlir
+++ b/tests/e2e/models/unidirectional_lstm.mlir
@@ -1,8 +1,8 @@
 // An example LSTM exported from a python reference model with dummy weights.
 
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=llvm-cpu %s --input="1x5xf32=[0,1,0,3,4]" --input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s
-// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx %s --input="1x5xf32=[0,1,0,3,4]" --input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s)
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s --input="1x5xf32=[0,1,0,3,4]" --input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=llvm-cpu %s --input="1x5xf32=[0,1,0,3,4]" --input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s
+// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vmvx %s --input="1x5xf32=[0,1,0,3,4]" --input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vulkan-spirv %s --input="1x5xf32=[0,1,0,3,4]" --input="1x5x2x2xf32=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]" | FileCheck %s)
 
 // Exported via the XLA HLO Importer
 // The resulting MLIR was modified by hand by changing all large constants to be
diff --git a/tests/e2e/regression/fill_i64.mlir b/tests/e2e/regression/fill_i64.mlir
index aaa45e4..512a7bc 100644
--- a/tests/e2e/regression/fill_i64.mlir
+++ b/tests/e2e/regression/fill_i64.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu %s --input=2x3xi64 | FileCheck %s
-// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vmvx %s --input=2x3xi64 | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu %s --input=2x3xi64 | FileCheck %s
+// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s --input=2x3xi64 | FileCheck %s)
 
 // CHECK: EXEC @fill_i64
 func.func @fill_i64(%arg0: tensor<?x?xi64>) -> (tensor<?x?xi64>, tensor<?x?xi64>) {
diff --git a/tests/e2e/regression/globals.mlir b/tests/e2e/regression/globals.mlir
index 0c2869c..32cf368 100644
--- a/tests/e2e/regression/globals.mlir
+++ b/tests/e2e/regression/globals.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-run-mlir --iree-hal-target-backends=vmvx %s | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
 
 util.global private mutable @counter = dense<2.0> : tensor<f32>
 
diff --git a/tests/e2e/regression/globals_ml_program.mlir b/tests/e2e/regression/globals_ml_program.mlir
index aa55d51..48507d1 100644
--- a/tests/e2e/regression/globals_ml_program.mlir
+++ b/tests/e2e/regression/globals_ml_program.mlir
@@ -1,5 +1,5 @@
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx %s | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vmvx %s | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-input-type=mhlo --Xcompiler,iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
 
 module {
   ml_program.global private mutable @counter(dense<2.0> : tensor<f32>): tensor<f32>
diff --git a/tests/e2e/regression/scalar.mlir b/tests/e2e/regression/scalar.mlir
index cb24c0d..3226650 100644
--- a/tests/e2e/regression/scalar.mlir
+++ b/tests/e2e/regression/scalar.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-run-mlir --iree-hal-target-backends=vmvx %s | FileCheck %s
-// RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu %s | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s | FileCheck %s
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu %s | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
 
 // CHECK-LABEL: EXEC @scalar
 func.func @scalar() -> i32 {
diff --git a/tests/e2e/regression/trace_dispatch_tensors.mlir b/tests/e2e/regression/trace_dispatch_tensors.mlir
index 2712ce4..7fde0c8 100644
--- a/tests/e2e/regression/trace_dispatch_tensors.mlir
+++ b/tests/e2e/regression/trace_dispatch_tensors.mlir
@@ -1,4 +1,8 @@
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx --iree-flow-trace-dispatch-tensors %s 2>&1 | FileCheck %s
+// RUN: iree-run-mlir \
+// RUN:   --Xcompiler,iree-input-type=mhlo \
+// RUN:   --Xcompiler,iree-hal-target-backends=vmvx \
+// RUN:   --Xcompiler,iree-flow-trace-dispatch-tensors \
+// RUN:   %s 2>&1 | FileCheck %s
 
 func.func @two_dispatch() -> (tensor<5x5xf32>, tensor<3x5xf32>) {
   %0 = util.unfoldable_constant dense<1.0> : tensor<5x3xf32>
diff --git a/tests/e2e/regression/unused_args.mlir b/tests/e2e/regression/unused_args.mlir
index 588daf7..2f2435a 100644
--- a/tests/e2e/regression/unused_args.mlir
+++ b/tests/e2e/regression/unused_args.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-run-mlir --iree-hal-target-backends=vmvx %s --input=4xf32=0 --input=4xf32=1 | FileCheck %s
-// RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu %s --input=4xf32=0 --input=4xf32=1 | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vulkan-spirv %s --input=4xf32=0 --input=4xf32=1 | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s --input=4xf32=0 --input=4xf32=1 | FileCheck %s
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu %s --input=4xf32=0 --input=4xf32=1 | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-hal-target-backends=vulkan-spirv %s --input=4xf32=0 --input=4xf32=1 | FileCheck %s)
 
 // CHECK-LABEL: EXEC @arg0_unused
 func.func @arg0_unused(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
diff --git a/tests/e2e/tensor_ops/tensor_cast.mlir b/tests/e2e/tensor_ops/tensor_cast.mlir
index 16b46f6..13217b9 100644
--- a/tests/e2e/tensor_ops/tensor_cast.mlir
+++ b/tests/e2e/tensor_ops/tensor_cast.mlir
@@ -1,6 +1,6 @@
-// RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu %s | FileCheck %s
-// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vmvx %s | FileCheck %s)
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu %s | FileCheck %s
+// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s | FileCheck %s)
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
 
 func.func @tensor_cast() -> tensor<2x?xf32> {
   %input = util.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
diff --git a/tests/microbenchmarks/dynamic_shape_vectorization.mlir b/tests/microbenchmarks/dynamic_shape_vectorization.mlir
index 7d0fbcd..c988a28 100644
--- a/tests/microbenchmarks/dynamic_shape_vectorization.mlir
+++ b/tests/microbenchmarks/dynamic_shape_vectorization.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-link-embedded=true --iree-llvmcpu-target-cpu-features='host' --iree-codegen-llvm-generic-ops-workgroup-size=2048 %s
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu --Xcompiler,iree-llvmcpu-target-cpu-features=host --Xcompiler,iree-codegen-llvm-generic-ops-workgroup-size=2048 %s
 
 //===----------------------------------------------------------------------===//
 // Dynamic shape micro-benchmarks.
diff --git a/tests/microbenchmarks/linalg_transpose.mlir b/tests/microbenchmarks/linalg_transpose.mlir
index 427a942..a45093a 100644
--- a/tests/microbenchmarks/linalg_transpose.mlir
+++ b/tests/microbenchmarks/linalg_transpose.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-link-embedded=true --iree-llvmcpu-target-cpu-features='host' --iree-codegen-llvm-generic-ops-workgroup-size=2048 %s
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu --Xcompiler,iree-llvmcpu-target-cpu-features=host --Xcompiler,iree-codegen-llvm-generic-ops-workgroup-size=2048 %s
 
 //===----------------------------------------------------------------------===//
 // Transpose ops.
diff --git a/tests/microbenchmarks/shared_mem_transpose.mlir b/tests/microbenchmarks/shared_mem_transpose.mlir
index 1ff2f69..6738f09 100644
--- a/tests/microbenchmarks/shared_mem_transpose.mlir
+++ b/tests/microbenchmarks/shared_mem_transpose.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-run-mlir --iree-hal-target-backends=cuda --iree-llvmcpu-link-embedded=true  %s
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=cuda %s
 
 //===----------------------------------------------------------------------===//
 // Transpose ops.
diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel
index a43c79c..63d1941 100644
--- a/tools/BUILD.bazel
+++ b/tools/BUILD.bazel
@@ -155,50 +155,31 @@
     srcs = ["iree-run-mlir-main.cc"],
     tags = ["hostonly"],
     deps = [
+        "//compiler/bindings/c:headers",
         "//compiler/src:defs",
-        "//compiler/src/iree/compiler/ConstEval",
-        "//compiler/src/iree/compiler/Dialect/HAL/Target",
-        "//compiler/src/iree/compiler/Dialect/VM/Target:init_targets",
-        "//compiler/src/iree/compiler/Dialect/VM/Target/Bytecode",
-        "//compiler/src/iree/compiler/Pipelines",
-        "//compiler/src/iree/compiler/Tools:init_passes_and_dialects",
-        "//compiler/src/iree/compiler/Tools:init_targets",
+        "//compiler/src/iree/compiler/API:Impl",
         "//runtime/src:runtime_defines",
         "//runtime/src/iree/base",
         "//runtime/src/iree/base:tracing",
         "//runtime/src/iree/base/internal:flags",
         "//runtime/src/iree/hal",
-        "//runtime/src/iree/modules/hal:types",
         "//runtime/src/iree/tooling:context_util",
         "//runtime/src/iree/tooling:device_util",
-        "//runtime/src/iree/tooling:vm_util",
+        "//runtime/src/iree/tooling:run_module",
         "//runtime/src/iree/vm",
-        "//runtime/src/iree/vm/bytecode:module",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
-        "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Support",
-        "@llvm-project//mlir:ToLLVMIRTranslation",
     ],
 )
 
 iree_runtime_cc_binary(
     name = "iree-run-module",
-    srcs = ["iree-run-module-main.cc"],
+    srcs = ["iree-run-module-main.c"],
     deps = [
         "//runtime/src/iree/base",
         "//runtime/src/iree/base:tracing",
         "//runtime/src/iree/base/internal:flags",
         "//runtime/src/iree/hal",
-        "//runtime/src/iree/modules/hal:types",
-        "//runtime/src/iree/tooling:comparison",
         "//runtime/src/iree/tooling:context_util",
-        "//runtime/src/iree/tooling:device_util",
-        "//runtime/src/iree/tooling:instrument_util",
-        "//runtime/src/iree/tooling:vm_util",
+        "//runtime/src/iree/tooling:run_module",
         "//runtime/src/iree/vm",
     ],
 )
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 4d3badf..3814889 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -172,18 +172,14 @@
   NAME
     iree-run-module
   SRCS
-    "iree-run-module-main.cc"
+    "iree-run-module-main.c"
   DEPS
     iree::base
     iree::base::internal::flags
     iree::base::tracing
     iree::hal
-    iree::modules::hal::types
-    iree::tooling::comparison
     iree::tooling::context_util
-    iree::tooling::device_util
-    iree::tooling::instrument_util
-    iree::tooling::vm_util
+    iree::tooling::run_module
     iree::vm
 )
 
@@ -300,29 +296,16 @@
     SRCS
       "iree-run-mlir-main.cc"
     DEPS
-      LLVMSupport
-      MLIRIR
-      MLIRParser
-      MLIRPass
-      MLIRSupport
-      MLIRTargetLLVMIRExport
       iree::base
       iree::base::internal::flags
       iree::base::tracing
-      iree::compiler::ConstEval
-      iree::compiler::Dialect::HAL::Target
-      iree::compiler::Dialect::VM::Target::Bytecode
-      iree::compiler::Dialect::VM::Target::init_targets
-      iree::compiler::Pipelines
-      iree::compiler::Tools::init_passes_and_dialects
-      iree::compiler::Tools::init_targets
+      iree::compiler::bindings::c::headers
+      iree::compiler::API::Impl
       iree::hal
-      iree::modules::hal::types
       iree::tooling::context_util
       iree::tooling::device_util
-      iree::tooling::vm_util
+      iree::tooling::run_module
       iree::vm
-      iree::vm::bytecode::module
     DATA
       ${IREE_LLD_TARGET}
     HOSTONLY
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index 1593ec5..63bfd5a 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -4,13 +4,12 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-// IREE source.mlir -> execution output runner.
+// IREE source.mlir/mlirbc -> execution output runner.
 // This is meant to be called from LIT for FileCheck tests or as a developer
 // tool to emulate what an online compiler does. It tries to match the interface
-// of iree-compile and iree-opt (featuring -split-input-file, etc) so it's
-// easy to run tests or approximate an iree-compile | iree-run-module sequence.
-// If you want a more generalized runner for standalone precompiled IREE modules
-// use iree-run-module instead.
+// of iree-compile so it's easy to run tests or approximate an
+// `iree-compile | iree-run-module` sequence. If you want a more generalized
+// runner for standalone precompiled IREE modules use iree-run-module instead.
 //
 // If there's a single exported function that will be executed and if there are
 // multiple functions --function= can be used to specify which is executed.
@@ -18,24 +17,45 @@
 // function will be printed to stdout for checking or can be written to files
 // with --output=.
 //
-// Example input:
-// // RUN: iree-run-mlir %s | FileCheck %s
-// // CHECK-LABEL: @foo
-// // CHECK: 2xf32=[2 3]
-// func.func @foo() -> tensor<2xf32> {
-//   %0 = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
-//   return %0 : tensor<2xf32>
-// }
+// Similar to iree-run-module, the --device= flag can be used to specify which
+// drivers and devices should be used to execute the function. The tool will
+// try to infer which iree-compile flags are required for the devices used but
+// this can be overridden by passing the --iree-hal-target-backends= and related
+// flags explicitly. Likewise if only the target backend is specified the
+// devices to use will be inferred unless explicitly specified.
+//
+// Example usage to compile and run with CUDA:
+// $ iree-run-mlir --device=cuda://0 file.mlir
+// or to compile with the LLVM CPU backend and run with the local-task driver:
+// $ iree-run-mlir file.mlir \
+//       --Xcompiler,iree-hal-target-backends=llvm-cpu --device=local-task
+//
+// Example usage in a lit test:
+//   // RUN: iree-run-mlir --device= %s --function=foo --input=2xf32=2,3 | \
+//   // RUN:   FileCheck %s
+//   // CHECK-LABEL: @foo
+//   // CHECK: 2xf32=[2 3]
+//   func.func @foo(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+//     return %arg0 : tensor<2xf32>
+//   }
 //
 // Command line arguments are handled by LLVM's parser by default but -- can be
 // used to separate the compiler flags from the runtime flags, such as:
-//   iree-run-mlir --iree-hal-target-backends=llvm-cpu -- --device=local-task
+// $ iree-run-mlir source.mlir --device=local-task -- \
+//       --iree-hal-target-backends=llvm-cpu
+//
+// In addition compiler/runtime flags can be passed in any order by prefixing
+// them with --Xcompiler or --Xruntime like `--Xruntime,device=local-task` or
+// `--Xruntime --device=local-task`.
 
 #include <cstdio>
 #include <cstring>
 #include <functional>
 #include <memory>
+#include <optional>
+#include <set>
 #include <string>
+#include <string_view>
 #include <tuple>
 #include <type_traits>
 #include <utility>
@@ -44,561 +64,477 @@
 #include "iree/base/api.h"
 #include "iree/base/internal/flags.h"
 #include "iree/base/tracing.h"
-#include "iree/compiler/ConstEval/Passes.h"
-#include "iree/compiler/Dialect/HAL/Target/TargetBackend.h"
-#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
-#include "iree/compiler/Dialect/VM/Target/init_targets.h"
-#include "iree/compiler/Pipelines/Pipelines.h"
-#include "iree/compiler/Tools/init_dialects.h"
-#include "iree/compiler/Tools/init_targets.h"
+#include "iree/compiler/embedding_api.h"
 #include "iree/hal/api.h"
-#include "iree/modules/hal/types.h"
 #include "iree/tooling/context_util.h"
 #include "iree/tooling/device_util.h"
-#include "iree/tooling/vm_util.h"
+#include "iree/tooling/run_module.h"
 #include "iree/vm/api.h"
-#include "iree/vm/bytecode/module.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/ADT/iterator.h"
-#include "llvm/ADT/iterator_range.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Error.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/InitLLVM.h"
-#include "llvm/Support/MemoryBuffer.h"
-#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
-#include "mlir/IR/AsmState.h"
-#include "mlir/IR/BlockSupport.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/OwningOpRef.h"
-#include "mlir/Parser/Parser.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Support/FileUtilities.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
-
-static llvm::cl::opt<std::string> input_file_flag{
-    llvm::cl::Positional,
-    llvm::cl::desc("<input .mlir file>"),
-    llvm::cl::init("-"),
-};
-
-static llvm::cl::opt<bool> verify_passes_flag(
-    "verify-each",
-    llvm::cl::desc("Run the verifier after each transformation pass"),
-    llvm::cl::init(true));
-
-static llvm::cl::opt<bool> print_mlir_flag{
-    "print-mlir",
-    llvm::cl::desc("Prints MLIR IR after translation"),
-    llvm::cl::init(false),
-};
-
-static llvm::cl::opt<bool> print_annotated_mlir_flag{
-    "print-annotated-mlir",
-    llvm::cl::desc("Prints MLIR IR with final serialization annotations"),
-    llvm::cl::init(false),
-};
-
-static llvm::cl::opt<bool> print_flatbuffer_flag{
-    "print-flatbuffer",
-    llvm::cl::desc("Prints Flatbuffer text after serialization"),
-    llvm::cl::init(false),
-};
-
-static llvm::cl::opt<std::string> output_file_flag{
-    "o",
-    llvm::cl::desc("File path in which to write the compiled module file"),
-    llvm::cl::init(""),
-};
-
-static llvm::cl::opt<bool> run_flag{
-    "run",
-    llvm::cl::desc("Runs the module (vs. just compiling and verifying)"),
-    llvm::cl::init(true),
-};
-
-static llvm::cl::list<std::string> run_args_flag{
-    "run-arg",
-    llvm::cl::desc("Argument passed to the execution flag parser"),
-    llvm::cl::ConsumeAfter,
-};
-
-IREE_FLAG(string, function, "",
-          "Name of a function contained in the compiled module. If omitted\n"
-          "and there's a single exported function that will be run instead.");
-
-IREE_FLAG_LIST(
-    string, input,
-    "An input (a) value or (b) buffer of the format:\n"
-    "  (a) scalar value\n"
-    "     value\n"
-    "     e.g.: --input=\"3.14\"\n"
-    "  (b) buffer:\n"
-    "     [shape]xtype=[value]\n"
-    "     e.g.: --input=\"2x2xi32=1 2 3 4\"\n"
-    "Optionally, brackets may be used to separate the element values:\n"
-    "  2x2xi32=[[1 2][3 4]]\n"
-    "Raw binary files can be read to provide buffer contents:\n"
-    "  2x2xi32=@some/file.bin\n"
-    "\n"
-    "Numpy npy files from numpy.save can be read to provide 1+ values:\n"
-    "  @some.npy\n"
-    "\n"
-    "Each occurrence of the flag indicates an input in the order they were\n"
-    "specified on the command line.");
-
-IREE_FLAG_LIST(
-    string, output,
-    "Specifies how to handle an output from the invocation:\n"
-    "  `` (empty): ignore output\n"
-    "     e.g.: --output=\n"
-    "  `-`: print textual form to stdout\n"
-    "     e.g.: --output=-\n"
-    "  `@file.npy`: create/overwrite a numpy npy file and write buffer view\n"
-    "     e.g.: --output=@file.npy\n"
-    "  `+file.npy`: create/append a numpy npy file and write buffer view\n"
-    "     e.g.: --output=+file.npy\n"
-    "\n"
-    "Numpy npy files can be read in Python using numpy.load, for example an\n"
-    "invocation producing two outputs can be concatenated as:\n"
-    "    --output=@file.npy --output=+file.npy\n"
-    "And then loaded in Python by reading from the same file:\n"
-    "  with open('file.npy', 'rb') as f:\n"
-    "    print(numpy.load(f))\n"
-    "    print(numpy.load(f))\n"
-    "\n"
-    "Each occurrence of the flag indicates an output in the order they were\n"
-    "specified on the command line.");
-
-IREE_FLAG(int32_t, output_max_element_count, 1024,
-          "Prints up to the maximum number of elements of output tensors, "
-          "eliding the remainder.");
 
 namespace iree {
 namespace {
 
-// Tries to guess a default device name from the backend, where possible.
+// Polyfill for std::string_view::starts_with, which arrives in C++20.
+// Note: this free-function form takes (prefix, string) arguments rather than
+// being a member of std::string_view.
+// https://en.cppreference.com/w/cpp/string/basic_string_view/starts_with
+bool starts_with(std::string_view prefix, std::string_view in_str) {
+  return in_str.size() >= prefix.size() &&
+         in_str.compare(0, prefix.size(), prefix) == 0;
+}
+
+// Tries to guess a default device name from the |target_backend| when possible.
 // Users are still able to override this by passing in --device= flags.
-std::string InferDefaultDeviceFromBackend(const std::string& backend) {
-  if (backend == "vmvx" || backend == "llvm-cpu") {
-    return "local-task";
-  } else if (backend == "vmvx-inline") {
+std::string InferDefaultDeviceFromTargetBackend(
+    std::string_view target_backend) {
+  if (target_backend == "" || target_backend == "vmvx-inline") {
+    // Plain VM or vmvx-inline targets do not need a HAL device.
     return "";
+  } else if (target_backend == "llvm-cpu" || target_backend == "vmvx") {
+    // Locally-executable targets default to the multithreaded task system
+    // driver; users can override by specifying --device=local-sync instead.
+    return "local-task";
   }
-  size_t dash = backend.find('-');
+  // Many other backends use the `driver-pipeline` naming like `vulkan-spirv`
+  // and we try that; device creation will fail if it's a bad guess.
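+  // e.g. the `vulkan-spirv` target backend maps to the `vulkan` driver.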
+  size_t dash = target_backend.find('-');
   if (dash == std::string::npos) {
-    return backend;
+    return std::string(target_backend);
   } else {
-    return backend.substr(0, dash);
+    return std::string(target_backend.substr(0, dash));
   }
 }
 
-// Returns a list of target compiler backends to use for file evaluation.
-Status GetTargetBackends(std::vector<std::string>* out_target_backends) {
-  IREE_TRACE_SCOPE();
-  out_target_backends->clear();
-  auto target_backends =
-      mlir::iree_compiler::IREE::HAL::TargetOptions::FromFlags::get().targets;
-  if (target_backends.empty()) {
-    iree_allocator_t host_allocator = iree_allocator_system();
-    iree_host_size_t driver_info_count = 0;
-    iree_hal_driver_info_t* driver_infos = NULL;
-    IREE_RETURN_IF_ERROR(iree_hal_driver_registry_enumerate(
-        iree_hal_available_driver_registry(), host_allocator,
-        &driver_info_count, &driver_infos));
-    for (iree_host_size_t i = 0; i < driver_info_count; ++i) {
-      target_backends.push_back(std::string(driver_infos[i].driver_name.data,
-                                            driver_infos[i].driver_name.size));
+// Tries to guess a target backend from the given |device_uri| when possible.
+// Returns an empty string if no backend is required or none can be inferred.
+std::string InferTargetBackendFromDevice(iree_string_view_t device_uri) {
+  // Get the driver name from URIs in the `driver://...` form.
+  iree_string_view_t driver = iree_string_view_empty();
+  iree_string_view_split(device_uri, ':', &driver, nullptr);
+  if (iree_string_view_is_empty(driver)) {
+    // Plain VM or vmvx-inline targets do not need a HAL device.
+    return "";
+  } else if (iree_string_view_starts_with(driver, IREE_SV("local-"))) {
+    // Locally-executable devices default to the llvm-cpu target as that's
+    // usually what people want for CPU execution; users can override by
+    // specifying --iree-hal-target-backends=vmvx instead.
+    return "llvm-cpu";
+  }
+  // Many other backends have aliases that allow using the driver name. If
+  // there are multiple pipelines available, whatever the compiler defaults to
+  // is chosen.
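+  // e.g. `--device=vulkan://0` infers the `vulkan` target backend alias.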
+  return std::string(driver.data, driver.size);
+}
+
+// Tries to guess a set of target backends from the |device_flag_values| when
+// possible. Since multiple target backends can be used for a particular device
+// (such as llvm-cpu or vmvx for local-sync and local-task) this is just
+// guesswork. If we can't produce a target backend flag value we bail.
+// Returns a comma-delimited list of target backends.
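+// e.g. `--device=local-task --device=vulkan://0` yields `llvm-cpu,vulkan`.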
+StatusOr<std::string> InferTargetBackendsFromDevices(
+    iree_host_size_t device_flag_count,
+    const iree_string_view_t* device_flag_values) {
+  // No-op when no devices are specified (probably no HAL).
+  if (device_flag_count == 0) return "";
+  // If multiple devices were provided we need to target all of them.
+  std::set<std::string> target_backends;
+  for (iree_host_size_t i = 0; i < device_flag_count; ++i) {
+    auto target_backend = InferTargetBackendFromDevice(device_flag_values[i]);
+    if (!target_backend.empty()) {
+      target_backends.insert(std::move(target_backend));
     }
-    iree_allocator_free(host_allocator, driver_infos);
   }
-  *out_target_backends = std::move(target_backends);
-  return OkStatus();
+  // Join all target backends together.
+  std::string result;
+  for (auto& target_backend : target_backends) {
+    if (!result.empty()) result.append(",");
+    result.append(target_backend);
+  }
+  return result;
 }
 
-void BuildDefaultIREEVMTransformPassPipeline(mlir::OpPassManager& passManager) {
-  static mlir::iree_compiler::IREEVMPipelineHooks defaultHooks = {
-      // buildConstEvalPassPipelineCallback =
-      [](mlir::OpPassManager& pm) {
-        pm.addPass(mlir::iree_compiler::ConstEval::createJitGlobalsPass());
-      }};
+// Configures the --iree-hal-target-backends= flag based on the --device= flags
+// set by the user. Ignored if any target backends are explicitly specified.
+// Online compilers would want to do some more intelligent device selection on
+// their own.
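+// When only target backends are set the reverse happens: a default runtime
+// device is inferred and returned via |out_default_device_uri|.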
+Status ConfigureTargetBackends(iree_compiler_session_t* session,
+                               std::string* out_default_device_uri) {
+  // Query the session for the currently set --iree-hal-target-backends= flag.
+  // It may be an empty string.
+  std::string target_backends_flag;
+  ireeCompilerSessionGetFlags(
+      session, /*nonDefaultOnly=*/true,
+      [](const char* flag_str, size_t length, void* user_data) {
+        // NOTE: flag_str has the full `--flag=value` string.
+        std::string_view prefix = "--iree-hal-target-backends=";
+        std::string_view flag = std::string_view(flag_str, length);
+        if (starts_with(prefix, flag)) {
+          flag.remove_prefix(prefix.size());
+          if (flag.empty()) return;  // ignore empty
+          auto* result = static_cast<std::string*>(user_data);
+          *result = std::string(flag);
+        }
+      },
+      static_cast<void*>(&target_backends_flag));
 
-  buildIREEVMTransformPassPipeline(
-      mlir::iree_compiler::BindingOptions::FromFlags::get(),
-      mlir::iree_compiler::InputDialectOptions::FromFlags::get(),
-      mlir::iree_compiler::PreprocessingOptions::FromFlags::get(),
-      mlir::iree_compiler::HighLevelOptimizationOptions::FromFlags::get(),
-      mlir::iree_compiler::SchedulingOptions::FromFlags::get(),
-      mlir::iree_compiler::IREE::HAL::TargetOptions::FromFlags::get(),
-      mlir::iree_compiler::IREE::VM::TargetOptions::FromFlags::get(),
-      defaultHooks, passManager);
-}
+  // Query the tooling utils for the --device= flag values. Note that zero or
+  // more devices may be specified.
+  iree_host_size_t device_flag_count = 0;
+  const iree_string_view_t* device_flag_values = NULL;
+  iree_hal_get_devices_flag_list(&device_flag_count, &device_flag_values);
 
-// Prepares a module for evaluation by running MLIR import and IREE translation.
-// Returns the serialized flatbuffer data.
-Status PrepareModule(std::string target_backend,
-                     std::unique_ptr<llvm::MemoryBuffer> file_buffer,
-                     mlir::DialectRegistry& registry, std::string* out_module) {
-  IREE_TRACE_SCOPE();
-  out_module->clear();
-
-  mlir::MLIRContext context;
-  context.appendDialectRegistry(registry);
-  context.allowUnregisteredDialects();
-
-  // Parse input MLIR module.
-  llvm::SourceMgr source_mgr;
-  source_mgr.AddNewSourceBuffer(std::move(file_buffer), llvm::SMLoc());
-  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
-      mlir::parseSourceFile<mlir::ModuleOp>(source_mgr, &context);
-  if (!mlir_module) {
-    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
-                            "could not parse MLIR file");
+  // No-op if no target backends or devices are specified - this can be an
+  // intentional decision as the user may be running a program that doesn't use
+  // the HAL.
+  if (target_backends_flag.empty() && device_flag_count == 0) {
+    return OkStatus();
   }
 
-  // Translate from MLIR to IREE bytecode.
-  printf("Compiling for target backend '%s'...\n", target_backend.c_str());
-  mlir::PassManager pass_manager(mlir_module->getContext());
-  pass_manager.enableVerifier(verify_passes_flag);
-  if (failed(mlir::applyPassManagerCLOptions(pass_manager))) {
-    return iree_make_status(IREE_STATUS_INTERNAL,
-                            "failed to apply pass manager CL options");
-  }
-  mlir::applyDefaultTimingPassManagerCLOptions(pass_manager);
-  BuildDefaultIREEVMTransformPassPipeline(pass_manager);
-  if (failed(pass_manager.run(mlir_module.get()))) {
-    return iree_make_status(IREE_STATUS_INTERNAL,
-                            "conversion from source -> vm failed");
+  // No-op if both target backends and devices are set as the user has
+  // explicitly specified a configuration.
+  if (!target_backends_flag.empty() && device_flag_count > 0) {
+    return OkStatus();
   }
 
-  if (print_mlir_flag) {
-    mlir_module->dump();
-  }
-
-  // NOTE: if we have an output file specified then we could compile into that
-  // for greater efficiency. Today we assume that users aren't passing multi-GB
-  // models through this tool (or if they are they have the memory to run them).
-  auto vm_options =
-      mlir::iree_compiler::IREE::VM::TargetOptions::FromFlags::get();
-  auto bytecode_options =
-      mlir::iree_compiler::IREE::VM::BytecodeTargetOptions::FromFlags::get();
-  std::string binary_contents;
-  llvm::raw_string_ostream binary_output(binary_contents);
-  if (failed(mlir::iree_compiler::IREE::VM::translateModuleToBytecode(
-          mlir_module.get(), vm_options, bytecode_options, binary_output))) {
-    return iree_make_status(
-        IREE_STATUS_INTERNAL,
-        "serialization to flatbuffer bytecode (binary) failed");
-  }
-  binary_output.flush();
-
-  // Print the annotated MLIR and flatbuffer; easiest way right now is to just
-  // do it all again.
-  if (print_annotated_mlir_flag) {
-    bytecode_options.outputFormat =
-        mlir::iree_compiler::IREE::VM::BytecodeOutputFormat::kAnnotatedMlirText;
-    std::string text_contents;
-    llvm::raw_string_ostream text_output(text_contents);
-    if (failed(mlir::iree_compiler::IREE::VM::translateModuleToBytecode(
-            mlir_module.get(), vm_options, bytecode_options, text_output))) {
-      return iree_make_status(IREE_STATUS_INTERNAL,
-                              "serialization to annotated MLIR (text) failed");
-    }
-    text_output.flush();
-    fprintf(stderr, "%s\n", text_contents.c_str());
-  }
-  if (print_flatbuffer_flag) {
-    bytecode_options.outputFormat =
-        mlir::iree_compiler::IREE::VM::BytecodeOutputFormat::kFlatBufferText;
-    std::string text_contents;
-    llvm::raw_string_ostream text_output(text_contents);
-    if (failed(mlir::iree_compiler::IREE::VM::translateModuleToBytecode(
-            mlir_module.get(), vm_options, bytecode_options, text_output))) {
+  // If target backends are specified then we can infer the runtime devices from
+  // the compiler configuration. This only works if there's a single backend
+  // specified; if the user wants multiple target backends then they must
+  // specify the device(s) to use.
+  if (device_flag_count == 0) {
+    if (target_backends_flag.find(',') != std::string::npos) {
       return iree_make_status(
-          IREE_STATUS_INTERNAL,
-          "serialization to flatbuffer bytecode (text) failed");
+          IREE_STATUS_INVALID_ARGUMENT,
+          "if multiple target backends are specified the device to use must "
+          "also be specified with --device= (have "
+          "`--iree-hal-target-backends=%.*s`)",
+          (int)target_backends_flag.size(), target_backends_flag.data());
     }
-    text_output.flush();
-    fprintf(stderr, "%s\n", text_contents.c_str());
-  }
-  if (!output_file_flag.empty()) {
-    if (llvm::writeToOutput(
-            output_file_flag, [&](llvm::raw_ostream& os) -> llvm::Error {
-              os.write(binary_contents.data(), binary_contents.size());
-              return llvm::Error::success();
-            })) {
-      return iree_make_status(IREE_STATUS_PERMISSION_DENIED,
-                              "unable to write module output to %s",
-                              output_file_flag.c_str());
-    }
-  }
-
-  *out_module = std::move(binary_contents);
-  return OkStatus();
-}
-
-// Evaluates a single function in its own fiber, printing the results to stdout.
-Status EvaluateFunction(iree_vm_context_t* context, iree_hal_device_t* device,
-                        iree_hal_allocator_t* device_allocator,
-                        iree_vm_function_t function,
-                        iree_string_view_t function_name) {
-  IREE_TRACE_SCOPE();
-  iree_allocator_t host_allocator = iree_allocator_system();
-
-  printf("EXEC @%.*s\n", (int)function_name.size, function_name.data);
-
-  // Parse input values from the flags.
-  vm::ref<iree_vm_list_t> inputs;
-  IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
-      device_allocator, FLAG_input_list().values, FLAG_input_list().count,
-      host_allocator, &inputs));
-
-  // If the function is async add fences so we can invoke it synchronously.
-  vm::ref<iree_hal_fence_t> finish_fence;
-  IREE_RETURN_IF_ERROR(iree_tooling_append_async_fence_inputs(
-      inputs.get(), &function, device, /*wait_fence=*/NULL, &finish_fence));
-
-  // Prepare outputs list to accept the results from the invocation.
-  vm::ref<iree_vm_list_t> outputs;
-  IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(),
-                                           16, host_allocator, &outputs));
-
-  // Synchronously invoke the function.
-  IREE_RETURN_IF_ERROR(iree_vm_invoke(
-      context, function, IREE_VM_INVOCATION_FLAG_NONE,
-      /*policy=*/nullptr, inputs.get(), outputs.get(), host_allocator));
-
-  // If the function is async we need to wait for it to complete.
-  if (finish_fence) {
-    IREE_RETURN_IF_ERROR(
-        iree_hal_fence_wait(finish_fence.get(), iree_infinite_timeout()));
-  }
-
-  // Print outputs.
-  if (FLAG_output_list().count == 0) {
-    IREE_RETURN_IF_ERROR(
-        iree_tooling_variant_list_fprint(
-            IREE_SV("result"), outputs.get(),
-            (iree_host_size_t)FLAG_output_max_element_count, stdout),
-        "printing results");
-  } else {
-    IREE_RETURN_IF_ERROR(
-        iree_tooling_output_variant_list(
-            outputs.get(), FLAG_output_list().values, FLAG_output_list().count,
-            (iree_host_size_t)FLAG_output_max_element_count, stdout),
-        "outputting results");
-  }
-
-  return OkStatus();
-}
-
-// Evaluates all exported functions within given module.
-Status EvaluateFunctions(iree_vm_instance_t* instance,
-                         const std::string& default_device_uri,
-                         const std::string& flatbuffer_data) {
-  IREE_TRACE_SCOPE0("EvaluateFunctions");
-
-  // Load any custom modules the user may have explicitly specified and then
-  // append the module we compiled.
-  iree_tooling_module_list_t module_list;
-  iree_tooling_module_list_initialize(&module_list);
-  IREE_RETURN_IF_ERROR(iree_tooling_load_modules_from_flags(
-      instance, iree_allocator_system(), &module_list));
-
-  // Load the bytecode module from the flatbuffer data.
-  // We do this first so that if we fail validation we know prior to dealing
-  // with devices.
-  vm::ref<iree_vm_module_t> main_module;
-  IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
-      instance,
-      iree_make_const_byte_span((void*)flatbuffer_data.data(),
-                                flatbuffer_data.size()),
-      iree_allocator_null(), iree_allocator_system(), &main_module));
-  IREE_RETURN_IF_ERROR(
-      iree_tooling_module_list_push_back(&module_list, main_module.get()));
-
-  if (!run_flag) {
-    // Just wanted verification; return without running.
-    main_module.reset();
-    iree_tooling_module_list_reset(&module_list);
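+    // e.g. a `llvm-cpu` target backend defaults the device to `local-task`.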
+    *out_default_device_uri =
+        InferDefaultDeviceFromTargetBackend(target_backends_flag);
     return OkStatus();
   }
 
-  // Choose which function to run - either the one specified in the flag or the
-  // only exported non-internal function.
-  iree_vm_function_t function = {0};
-  if (strlen(FLAG_function) == 0) {
-    IREE_RETURN_IF_ERROR(iree_tooling_find_single_exported_function(
-        main_module.get(), &function));
-  } else {
-    IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name(
-                             main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
-                             iree_make_cstring_view(FLAG_function), &function),
-                         "looking up function '%s'", FLAG_function);
+  // Infer target backends from the runtime device configuration.
+  // This can get arbitrarily complex but for now this simple runner just
+  // guesses. In the future we'll have more ways of configuring the compiler
+  // from available runtime devices (not just the target backend but also
+  // target-specific settings).
+  IREE_ASSIGN_OR_RETURN(
+      auto target_backends,
+      InferTargetBackendsFromDevices(device_flag_count, device_flag_values));
+  if (!target_backends.empty()) {
+    auto target_backends_flag =
+        std::string("--iree-hal-target-backends=") + target_backends;
+    const char* compiler_argv[1] = {
+        target_backends_flag.c_str(),
+    };
+    if (auto error = ireeCompilerSessionSetFlags(
+            session, IREE_ARRAYSIZE(compiler_argv), compiler_argv)) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "unable to set inferred target backend flag to `%.*s`",
+          (int)target_backends_flag.size(), target_backends_flag.data());
+    }
   }
 
-  // Evaluate all exported functions.
-  auto run_function = [&](iree_vm_function_t function) -> Status {
-    iree_string_view_t function_name = iree_vm_function_name(&function);
-
-    // Create the context we'll use for this (ensuring that we can't interfere
-    // with other running evaluations, such as when in a multithreaded test
-    // runner).
-    vm::ref<iree_vm_context_t> context;
-    vm::ref<iree_hal_device_t> device;
-    vm::ref<iree_hal_allocator_t> device_allocator;
-    IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags(
-        instance, module_list.count, module_list.values,
-        iree_make_string_view(default_device_uri.data(),
-                              default_device_uri.size()),
-        iree_allocator_system(), &context, &device, &device_allocator));
-
-    IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
-
-    // Invoke the function and print results.
-    IREE_RETURN_IF_ERROR(
-        EvaluateFunction(context.get(), device.get(), device_allocator.get(),
-                         function, function_name),
-        "evaluating export function %.*s", (int)function_name.size,
-        function_name.data);
-
-    IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
-
-    context.reset();
-    device_allocator.reset();
-    device.reset();
-    return OkStatus();
-  };
-  Status evaluate_status = run_function(function);
-
-  main_module.reset();
-  iree_tooling_module_list_reset(&module_list);
-
-  return evaluate_status;
-}
-
-// Translates and runs a single LLVM file buffer.
-Status EvaluateFile(std::unique_ptr<llvm::MemoryBuffer> file_buffer,
-                    mlir::DialectRegistry& registry) {
-  IREE_TRACE_SCOPE0("EvaluateFile");
-
-  vm::ref<iree_vm_instance_t> instance;
-  IREE_RETURN_IF_ERROR(
-      iree_tooling_create_instance(iree_allocator_system(), &instance),
-      "Creating instance");
-
-  std::vector<std::string> target_backends;
-  IREE_RETURN_IF_ERROR(GetTargetBackends(&target_backends));
-  for (const auto& target_backend : target_backends) {
-    // Prepare the module for execution and evaluate it.
-    IREE_TRACE_FRAME_MARK();
-    auto cloned_file_buffer = llvm::MemoryBuffer::getMemBufferCopy(
-        file_buffer->getBuffer(), file_buffer->getBufferIdentifier());
-    std::string flatbuffer_data;
-    IREE_RETURN_IF_ERROR(
-        PrepareModule(target_backend + '*', std::move(cloned_file_buffer),
-                      registry, &flatbuffer_data),
-        "Translating module");
-    IREE_TRACE_FRAME_MARK();
-    std::string default_device_uri =
-        InferDefaultDeviceFromBackend(target_backend);
-    IREE_RETURN_IF_ERROR(
-        EvaluateFunctions(instance.get(), default_device_uri, flatbuffer_data),
-        "Evaluating functions");
-  }
-
-  instance.reset();
   return OkStatus();
 }
 
 // Runs the given .mlir file based on the current flags.
-Status RunFile(const std::string& mlir_filename,
-               mlir::DialectRegistry& registry) {
-  IREE_TRACE_SCOPE0("RunFile");
+StatusOr<int> CompileAndRunFile(iree_compiler_session_t* session,
+                                const char* mlir_filename) {
+  IREE_TRACE_SCOPE0("CompileAndRunFile");
 
-  // Load input file/from stdin.
-  std::string error_message;
-  auto file = mlir::openInputFile(mlir_filename, &error_message);
-  if (!file) {
-    return iree_make_status(
-        IREE_STATUS_NOT_FOUND, "unable to open input file %.*s: %s",
-        (int)mlir_filename.size(), mlir_filename.data(), error_message.c_str());
+  // Configure the --iree-hal-target-backends= flag and/or get the default
+  // device to use at runtime if either is not explicitly specified.
+  // Note that target backends and runtime devices aren't 1:1 and this is an
+  // imperfect guess. In this simple online compiler we assume homogeneous
+  // device sets and only a single global target backend, but library/hosting
+  // layers can configure heterogeneous and per-invocation target
+  // configurations.
+  std::string default_device_uri;
+  IREE_RETURN_IF_ERROR(ConfigureTargetBackends(session, &default_device_uri));
+
+  // RAII container for the compiler invocation.
+  struct InvocationState {
+    iree_compiler_invocation_t* invocation = nullptr;
+    iree_compiler_source_t* source = nullptr;
+    iree_compiler_output_t* output = nullptr;
+
+    explicit InvocationState(iree_compiler_session_t* session) {
+      invocation = ireeCompilerInvocationCreate(session);
+    }
+
+    ~InvocationState() {
+      if (source) ireeCompilerSourceDestroy(source);
+      if (output) ireeCompilerOutputDestroy(output);
+      ireeCompilerInvocationDestroy(invocation);
+    }
+
+    Status emitError(iree_compiler_error_t* error,
+                     iree_status_code_t status_code,
+                     std::string_view while_performing = "") {
+      const char* msg = ireeCompilerErrorGetMessage(error);
+      fprintf(stderr, "error compiling input file: %s\n", msg);
+      iree_status_t status = iree_make_status(status_code, "%s", msg);
+      if (!while_performing.empty()) {
+        status = iree_status_annotate(
+            status, iree_make_string_view(while_performing.data(),
+                                          while_performing.size()));
+      }
+      ireeCompilerErrorDestroy(error);
+      return status;
+    }
+  } state(session);
+
+  // Open the source file on disk or stdin if `-`.
+  if (auto error =
+          ireeCompilerSourceOpenFile(session, mlir_filename, &state.source)) {
+    return state.emitError(error, IREE_STATUS_NOT_FOUND, "opening source file");
   }
 
-  // Use entire buffer as a single module.
-  return EvaluateFile(std::move(file), registry);
+  // Open a writeable memory buffer that we can stream the compilation outputs
+  // into. This may be backed by a memory-mapped file to allow for very large
+  // results.
+  if (auto error = ireeCompilerOutputOpenMembuffer(&state.output)) {
+    return state.emitError(error, IREE_STATUS_INTERNAL,
+                           "open output memory buffer");
+  }
+
+  // TODO: make parsing/pipeline execution return an error object.
+  // We could capture diagnostics, stash them on the state, and emit with
+  // ireeCompilerInvocationEnableCallbackDiagnostics.
+  // For now we route all errors to stderr.
+  ireeCompilerInvocationEnableConsoleDiagnostics(state.invocation);
+
+  // Parse the source MLIR input and log verbose errors. Syntax errors or
+  // version mismatches will hit here.
+  if (!ireeCompilerInvocationParseSource(state.invocation, state.source)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "failed to parse input file");
+  }
+
+  // Invoke the standard compilation pipeline to produce the compiled module.
+  if (!ireeCompilerInvocationPipeline(state.invocation,
+                                      IREE_COMPILER_PIPELINE_STD)) {
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "failed to invoke main compiler pipeline");
+  }
+
+  // Emit the compiled VM bytecode into the output memory buffer.
+  if (auto error = ireeCompilerInvocationOutputVMBytecode(state.invocation,
+                                                          state.output)) {
+    return state.emitError(error, IREE_STATUS_INTERNAL,
+                           "emitting output VM module binary");
+  }
+
+  // Get a raw host pointer to the output that we can pass to the runtime.
+  void* binary_data = nullptr;
+  uint64_t binary_size = 0;
+  if (auto error = ireeCompilerOutputMapMemory(state.output, &binary_data,
+                                               &binary_size)) {
+    return state.emitError(error, IREE_STATUS_INTERNAL,
+                           "mapping output buffer");
+  }
+
+  // Hosting libraries can route all runtime allocations to their own allocator
+  // for statistics, isolation, or efficiency. Here we use the system
+  // malloc/free.
+  iree_allocator_t host_allocator = iree_allocator_system();
+
+  // The same VM instance should be shared across many contexts. Here we only
+  // use this once but a library would want to retain this and the devices it
+  // creates for as long as practical.
+  vm::ref<iree_vm_instance_t> instance;
+  IREE_RETURN_IF_ERROR(iree_tooling_create_instance(host_allocator, &instance),
+                       "creating instance");
+
+  // Run the compiled module using the global flags for I/O (if any).
+  // This loads the module, creates a VM context with it and any dependencies,
+  // parses inputs from flags, and routes/verifies outputs as specified. Hosting
+  // libraries should always reuse contexts if possible to amortize loading
+  // costs and carry state (variables/etc) across invocations.
+  //
+  // This returns a process exit code based on the run mode (verifying expected
+  // outputs, etc) that may be non-zero even if the status is success
+  // ("execution completed successfully but values did not match").
+  int exit_code = EXIT_SUCCESS;
+  IREE_RETURN_IF_ERROR(
+      iree_tooling_run_module_with_data(
+          instance.get(),
+          iree_make_string_view(default_device_uri.data(),
+                                default_device_uri.size()),
+          iree_make_const_byte_span(binary_data, (iree_host_size_t)binary_size),
+          host_allocator, &exit_code),
+      "running compiled module");
+  return exit_code;
 }
 
+// Parses a combined list of compiler and runtime flags.
+// Each argument list is stored in canonical argc/argv format with a trailing
+// NULL string in the storage (excluded from the count).
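+//
+// e.g. `iree-run-mlir a.mlir --device=local-task -- \
+//       --iree-hal-target-backends=llvm-cpu`
+// routes `a.mlir` and `--device=local-task` to the runtime flag set and
+// `--iree-hal-target-backends=llvm-cpu` to the compiler flag set.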
+class ArgParser {
+ public:
+  int compiler_argc() { return compiler_args_.size() - 1; }
+  const char** compiler_argv() {
+    return const_cast<const char**>(compiler_args_.data());
+  }
+
+  int runtime_argc() { return runtime_args_.size() - 1; }
+  char** runtime_argv() { return runtime_args_.data(); }
+
+  // Parses arguments from a raw command line argc/argv set.
+  // Returns true if parsing was successful.
+  bool Parse(int argc_raw, char** argv_raw) {
+    // Pre-process the arguments with the compiler's argument parser since it
+    // has super-powers on Windows and must work on the default main arguments.
+    ireeCompilerGetProcessCLArgs(&argc_raw,
+                                 const_cast<const char***>(&argv_raw));
+
+    // Always add the progname to both flag sets.
+    compiler_args_.push_back(argv_raw[0]);
+    runtime_args_.push_back(argv_raw[0]);
+
+    // Everything before -- goes to the runtime.
+    // Everything after -- goes to the compiler.
+    // To make it easier to form command lines in scripts we also allow
+    // prefixing flags with -Xcompiler/-Xruntime on either side of the --.
+    bool parsing_runtime_args = true;
+    for (int i = 1; i < argc_raw; ++i) {
+      char* current_arg_cstr = argv_raw[i];
+      char* next_arg_cstr =
+          argv_raw[i + 1];  // ok because list is NULL-terminated
+      auto current_arg = std::string_view(current_arg_cstr);
+      if (current_arg == "--") {
+        // Switch default parsing to compiler flags.
+        parsing_runtime_args = false;
+      } else if (current_arg == "-Xcompiler" || current_arg == "--Xcompiler") {
+        // Next arg is routed to the compiler.
+        compiler_args_.push_back(next_arg_cstr);
+      } else if (current_arg == "-Xruntime" || current_arg == "--Xruntime") {
+        // Next arg is routed to the runtime.
+        runtime_args_.push_back(next_arg_cstr);
+      } else if (starts_with("-Xcompiler,", current_arg) ||
+                 starts_with("--Xcompiler,", current_arg)) {
+        // Split and send the rest of the flag to the compiler.
+        AppendPrefixedArgs(current_arg, &compiler_args_);
+      } else if (starts_with("-Xruntime,", current_arg) ||
+                 starts_with("--Xruntime,", current_arg)) {
+        // Split and send the rest of the flag to the runtime.
+        AppendPrefixedArgs(current_arg, &runtime_args_);
+      } else {
+        // Route to either runtime or compiler arg sets based on which side of
+        // the -- we are on.
+        if (parsing_runtime_args) {
+          runtime_args_.push_back(current_arg_cstr);
+        } else {
+          compiler_args_.push_back(current_arg_cstr);
+        }
+      }
+    }
+
+    // Add nullptrs to end to match real argv behavior.
+    compiler_args_.push_back(nullptr);
+    runtime_args_.push_back(nullptr);
+
+    return true;
+  }
+
+ private:
+  // Drops the prefix from |prefixed_arg| and appends one or more flags to
+  // |out_args|.
+  // Example: --Xcompiler,ab=cd,ef=gh -> --ab=cd + --ef=gh
+  void AppendPrefixedArgs(std::string_view prefixed_arg,
+                          std::vector<char*>* out_args) {
+    auto append_flag_string = [&](std::string_view slice_arg) {
+      auto stable_arg = std::make_unique<std::string>("--");
+      stable_arg->append(slice_arg);
+      temp_strings_.push_back(std::move(stable_arg));
+      out_args->push_back(temp_strings_.back()->data());
+    };
+    std::string_view sub_arg = prefixed_arg.substr(prefixed_arg.find(',') + 1);
+    for (;;) {
+      size_t comma_pos = sub_arg.find_first_of(',');
+      if (comma_pos == std::string_view::npos) break;
+      append_flag_string(sub_arg.substr(0, comma_pos));
+      sub_arg = sub_arg.substr(comma_pos + 1);
+    }
+    append_flag_string(sub_arg);
+  }
+
+  std::vector<std::unique_ptr<std::string>> temp_strings_;
+  std::vector<char*> runtime_args_;
+  std::vector<char*> compiler_args_;
+};
+
 }  // namespace
 
-extern "C" int main(int argc_llvm, char** argv_llvm) {
+extern "C" int main(int argc, char** argv) {
   IREE_TRACE_SCOPE0("iree-run-mlir");
 
-  mlir::DialectRegistry registry;
-  mlir::iree_compiler::registerAllDialects(registry);
-  mlir::iree_compiler::registerHALTargetBackends();
-  mlir::iree_compiler::registerVMTargets();
-  mlir::registerBuiltinDialectTranslation(registry);
-  mlir::registerLLVMDialectTranslation(registry);
-  // Make sure command line options are registered.
-  // Flag options structs (must resolve prior to CLI parsing).
-  (void)mlir::iree_compiler::BindingOptions::FromFlags::get();
-  (void)mlir::iree_compiler::InputDialectOptions::FromFlags::get();
-  (void)mlir::iree_compiler::HighLevelOptimizationOptions::FromFlags::get();
-  (void)mlir::iree_compiler::SchedulingOptions::FromFlags::get();
-  (void)mlir::iree_compiler::IREE::HAL::TargetOptions::FromFlags::get();
-  (void)mlir::iree_compiler::IREE::VM::TargetOptions::FromFlags::get();
-  (void)mlir::iree_compiler::IREE::VM::BytecodeTargetOptions::FromFlags::get();
+  // Initialize the compiler once on startup before using any other APIs.
+  ireeCompilerGlobalInitialize();
 
-  // Register MLIRContext command-line options like
-  // -mlir-print-op-on-diagnostic.
-  mlir::registerMLIRContextCLOptions();
-  // Register assembly printer command-line options like
-  // -mlir-print-op-generic.
-  mlir::registerAsmPrinterCLOptions();
-  // Register pass manager command-line options like -mlir-print-ir-*.
-  mlir::registerPassManagerCLOptions();
-
-  // On Windows InitLLVM re-queries the command line from Windows directly and
-  // totally messes up the array.
-  llvm::setBugReportMsg(
-      "Please report issues to https://github.com/openxla/iree/issues and "
-      "include the crash backtrace.\n");
-  llvm::InitLLVM init_llvm(argc_llvm, argv_llvm);
-  llvm::cl::ParseCommandLineOptions(argc_llvm, argv_llvm);
-
-  // Consume all options after the positional filename and pass them to the IREE
-  // flag parser.
-  std::vector<char*> argv_iree = {argv_llvm[0]};
-  for (auto& run_arg : run_args_flag) {
-    if (run_arg == "--") continue;
-    argv_iree.push_back(const_cast<char*>(run_arg.c_str()));
-  }
-  int argc_iree = static_cast<int>(argv_iree.size());
-  char** argv_iree_ptr = argv_iree.data();
-  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc_iree,
-                           &argv_iree_ptr);
-
-  auto status = RunFile(input_file_flag, registry);
-  if (!status.ok()) {
-    fprintf(stderr, "ERROR running file (%s):\n%s\n", input_file_flag.c_str(),
-            status.ToString().c_str());
+  // Parse full argument list and split into compiler/runtime flag sets.
+  ArgParser arg_parser;
+  if (!arg_parser.Parse(argc, argv)) {
+    ireeCompilerGlobalShutdown();
     return 1;
   }
-  return 0;
+
+  // Pass along compiler flags.
+  // Since this is a command line tool we initialize the global compiler
+  // command line environment prior to processing the sources.
+  // In-process/library uses would usually not do this and would set session
+  // specific arguments as needed from whatever configuration mechanisms they
+  // use (kwargs passed to python functions, etc).
+  ireeCompilerSetupGlobalCL(arg_parser.compiler_argc(),
+                            arg_parser.compiler_argv(), "iree-run-mlir",
+                            /*installSignalHandlers=*/true);
+
+  // Pass along runtime flags.
+  // Note that positional args are left in runtime_argv (after progname).
+  // Runtime flags are generally only useful in command line tools where there's
+  // a fixed set of devices, a short lifetime, a single thread, and a single
+  // context/set of modules/etc. Hosting applications can programmatically
+  // do most of what the flags do in a way that avoids the downsides of such
+  // global one-shot configuration.
+  int runtime_argc = arg_parser.runtime_argc();
+  char** runtime_argv = arg_parser.runtime_argv();
+  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &runtime_argc,
+                           &runtime_argv);
+
+  // Ensure a source file was found.
+  if (runtime_argc != 2) {
+    fprintf(stderr,
+            "ERROR: one source MLIR file must be specified.\n"
+            "Pass either the path to a .mlir/mlirbc file or `-` to read from "
+            "stdin.\n");
+    fflush(stderr);
+    return 1;
+  }
+  const char* source_filename = runtime_argv[1];
+
+  // Sessions can be reused for many compiler invocations.
+  iree_compiler_session_t* session = ireeCompilerSessionCreate();
+
+  // The process return code is 0 for success and non-zero otherwise.
+  // We don't differentiate between compiler or runtime error codes here but
+  // could if someone found it useful.
+  int rc = EXIT_SUCCESS;
+
+  // Compile and run the provided source file and get the exit code determined
+  // based on the run mode.
+  auto status_or = CompileAndRunFile(session, source_filename);
+  if (status_or.ok()) {
+    rc = status_or.value();
+  } else {
+    rc = 2;
+    iree_status_fprint(stderr, status_or.status().get());
+    fflush(stderr);
+  }
+
+  ireeCompilerSessionDestroy(session);
+
+  // No more compiler APIs can be called after this point.
+  ireeCompilerGlobalShutdown();
+  return rc;
 }
 
 }  // namespace iree
diff --git a/tools/iree-run-module-main.c b/tools/iree-run-module-main.c
new file mode 100644
index 0000000..f18d5c5
--- /dev/null
+++ b/tools/iree-run-module-main.c
@@ -0,0 +1,57 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/base/api.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/tracing.h"
+#include "iree/tooling/context_util.h"
+#include "iree/tooling/run_module.h"
+#include "iree/vm/api.h"
+
+int main(int argc, char** argv) {
+  int exit_code = EXIT_SUCCESS;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Parse command line flags.
+  iree_flags_set_usage(
+      "iree-run-module",
+      "Runs a function within a compiled IREE module and handles I/O parsing\n"
+      "and optional expected value verification/output processing. Modules\n"
+      "can be provided by file path (`--module=file.vmfb`) or read from stdin\n"
+      "(`--module=-`) and the function to execute matches the original name\n"
+      "provided to the compiler (`--function=foo` for `func.func @foo`).\n");
+  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
+
+  // Hosting applications can provide their own allocators to pool resources or
+  // track allocation statistics related to IREE code.
+  iree_allocator_t host_allocator = iree_allocator_system();
+  // Hosting applications should reuse instances across multiple contexts that
+  // have similar composition (similar types/modules/etc). Most applications can
+  // get by with a single shared instance.
+  iree_vm_instance_t* instance = NULL;
+  iree_status_t status =
+      iree_tooling_create_instance(host_allocator, &instance);
+
+  // Utility to run the module with the command line flags. This particular
+  // method is only useful in IREE tools that want consistent flags; a real
+  // application will need to do what this is doing with its own setup and
+  // I/O handling.
+  if (iree_status_is_ok(status)) {
+    status = iree_tooling_run_module_from_flags(instance, host_allocator,
+                                                &exit_code);
+  }
+
+  iree_vm_instance_release(instance);
+
+  if (!iree_status_is_ok(status)) {
+    iree_status_fprint(stderr, status);
+    iree_status_free(status);
+    exit_code = EXIT_FAILURE;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return exit_code;
+}
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
deleted file mode 100644
index 0ce8750..0000000
--- a/tools/iree-run-module-main.cc
+++ /dev/null
@@ -1,241 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include <array>
-#include <cstdio>
-#include <cstdlib>
-#include <iterator>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "iree/base/api.h"
-#include "iree/base/internal/flags.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/api.h"
-#include "iree/modules/hal/types.h"
-#include "iree/tooling/comparison.h"
-#include "iree/tooling/context_util.h"
-#include "iree/tooling/device_util.h"
-#include "iree/tooling/instrument_util.h"
-#include "iree/tooling/vm_util.h"
-#include "iree/vm/api.h"
-
-IREE_FLAG(string, function, "",
-          "Name of a function contained in the module specified by --module= "
-          "to run.");
-
-IREE_FLAG(bool, print_statistics, false,
-          "Prints runtime statistics to stderr on exit.");
-
-IREE_FLAG_LIST(
-    string, input,
-    "An input (a) value or (b) buffer of the format:\n"
-    "  (a) scalar value\n"
-    "     value\n"
-    "     e.g.: --input=\"3.14\"\n"
-    "  (b) buffer:\n"
-    "     [shape]xtype=[value]\n"
-    "     e.g.: --input=\"2x2xi32=1 2 3 4\"\n"
-    "Optionally, brackets may be used to separate the element values:\n"
-    "  2x2xi32=[[1 2][3 4]]\n"
-    "Raw binary files can be read to provide buffer contents:\n"
-    "  2x2xi32=@some/file.bin\n"
-    "\n"
-    "Numpy npy files from numpy.save can be read to provide 1+ values:\n"
-    "  @some.npy\n"
-    "\n"
-    "Each occurrence of the flag indicates an input in the order they were\n"
-    "specified on the command line.");
-
-IREE_FLAG_LIST(
-    string, output,
-    "Specifies how to handle an output from the invocation:\n"
-    "  `` (empty): ignore output\n"
-    "     e.g.: --output=\n"
-    "  `-`: print textual form to stdout\n"
-    "     e.g.: --output=-\n"
-    "  `@file.npy`: create/overwrite a numpy npy file and write buffer view\n"
-    "     e.g.: --output=@file.npy\n"
-    "  `+file.npy`: create/append a numpy npy file and write buffer view\n"
-    "     e.g.: --output=+file.npy\n"
-    "\n"
-    "Numpy npy files can be read in Python using numpy.load, for example an\n"
-    "invocation producing two outputs can be concatenated as:\n"
-    "    --output=@file.npy --output=+file.npy\n"
-    "And then loaded in Python by reading from the same file:\n"
-    "  with open('file.npy', 'rb') as f:\n"
-    "    print(numpy.load(f))\n"
-    "    print(numpy.load(f))\n"
-    "\n"
-    "Each occurrence of the flag indicates an output in the order they were\n"
-    "specified on the command line.");
-
-IREE_FLAG_LIST(string, expected_output,
-               "An expected function output following the same format as "
-               "--input. When present the results of the "
-               "invocation will be compared against these values and the "
-               "tool will return non-zero if any differ. If the value of a "
-               "particular output is not of interest provide `(ignored)`.");
-
-IREE_FLAG(int32_t, output_max_element_count, 1024,
-          "Prints up to the maximum number of elements of output tensors, "
-          "eliding the remainder.");
-
-namespace iree {
-namespace {
-
-iree_status_t Run(int* out_exit_code) {
-  IREE_TRACE_SCOPE0("iree-run-module");
-
-  iree_allocator_t host_allocator = iree_allocator_system();
-  vm::ref<iree_vm_instance_t> instance;
-  IREE_RETURN_IF_ERROR(iree_tooling_create_instance(host_allocator, &instance),
-                       "creating instance");
-
-  iree_tooling_module_list_t module_list;
-  iree_tooling_module_list_initialize(&module_list);
-  IREE_RETURN_IF_ERROR(iree_tooling_load_modules_from_flags(
-      instance.get(), host_allocator, &module_list));
-
-  vm::ref<iree_vm_context_t> context;
-  vm::ref<iree_hal_device_t> device;
-  vm::ref<iree_hal_allocator_t> device_allocator;
-  IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags(
-      instance.get(), module_list.count, module_list.values,
-      /*default_device_uri=*/iree_string_view_empty(), host_allocator, &context,
-      &device, &device_allocator));
-
-  std::string function_name = std::string(FLAG_function);
-  iree_vm_function_t function;
-  if (function_name.empty()) {
-    IREE_RETURN_IF_ERROR(iree_tooling_find_single_exported_function(
-        iree_tooling_module_list_back(&module_list), &function));
-  } else {
-    IREE_RETURN_IF_ERROR(
-        iree_vm_module_lookup_function_by_name(
-            iree_tooling_module_list_back(&module_list),
-            IREE_VM_FUNCTION_LINKAGE_EXPORT,
-            iree_string_view_t{function_name.data(), function_name.size()},
-            &function),
-        "looking up function '%s'", function_name.c_str());
-  }
-
-  vm::ref<iree_vm_list_t> inputs;
-  IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
-      device_allocator.get(), FLAG_input_list().values, FLAG_input_list().count,
-      host_allocator, &inputs));
-
-  // If the function is async add fences so we can invoke it synchronously.
-  vm::ref<iree_hal_fence_t> finish_fence;
-  IREE_RETURN_IF_ERROR(iree_tooling_append_async_fence_inputs(
-      inputs.get(), &function, device.get(), /*wait_fence=*/NULL,
-      &finish_fence));
-
-  vm::ref<iree_vm_list_t> outputs;
-  IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(),
-                                           16, host_allocator, &outputs));
-
-  printf("EXEC @%s\n", function_name.c_str());
-
-  IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
-
-  IREE_RETURN_IF_ERROR(
-      iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE,
-                     /*policy=*/nullptr, inputs.get(), outputs.get(),
-                     host_allocator),
-      "invoking function '%s'", function_name.c_str());
-
-  // If the function is async we need to wait for it to complete.
-  if (finish_fence) {
-    IREE_RETURN_IF_ERROR(
-        iree_hal_fence_wait(finish_fence.get(), iree_infinite_timeout()));
-  }
-
-  IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
-
-  if (FLAG_expected_output_list().count == 0) {
-    if (FLAG_output_list().count == 0) {
-      IREE_RETURN_IF_ERROR(
-          iree_tooling_variant_list_fprint(
-              IREE_SV("result"), outputs.get(),
-              (iree_host_size_t)FLAG_output_max_element_count, stdout),
-          "printing results");
-    } else {
-      IREE_RETURN_IF_ERROR(
-          iree_tooling_output_variant_list(
-              outputs.get(), FLAG_output_list().values,
-              FLAG_output_list().count,
-              (iree_host_size_t)FLAG_output_max_element_count, stdout),
-          "outputting results");
-    }
-  } else {
-    // Parse expected list into host-local memory that we can easily access.
-    // Note that we return a status here as this can fail on user inputs.
-    vm::ref<iree_hal_allocator_t> heap_allocator;
-    IREE_RETURN_IF_ERROR(iree_hal_allocator_create_heap(
-        IREE_SV("heap"), host_allocator, host_allocator, &heap_allocator));
-    vm::ref<iree_vm_list_t> expected_list;
-    IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
-        heap_allocator.get(), FLAG_expected_output_list().values,
-        FLAG_expected_output_list().count, host_allocator, &expected_list));
-
-    // Compare expected vs actual lists and output diffs.
-    bool did_match = iree_tooling_compare_variant_lists(
-        expected_list.get(), outputs.get(), host_allocator, stdout);
-    if (did_match) {
-      printf("[SUCCESS] all function outputs matched their expected values.\n");
-    }
-    *out_exit_code = did_match ? EXIT_SUCCESS : EXIT_FAILURE;
-  }
-
-  // Grab any instrumentation data present in the module and write it to disk.
-  IREE_RETURN_IF_ERROR(
-      iree_tooling_process_instrument_data(context.get(), host_allocator));
-
-  // Release resources before gathering statistics.
-  inputs.reset();
-  outputs.reset();
-  iree_tooling_module_list_reset(&module_list);
-  context.reset();
-
-  if (device_allocator && FLAG_print_statistics) {
-    IREE_IGNORE_ERROR(
-        iree_hal_allocator_statistics_fprint(stderr, device_allocator.get()));
-  }
-
-  device_allocator.reset();
-  device.reset();
-  instance.reset();
-  return iree_ok_status();
-}
-
-}  // namespace
-
-extern "C" int main(int argc, char** argv) {
-  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
-  if (argc > 1) {
-    // Avoid iree-run-module spinning endlessly on stdin if the user uses single
-    // dashes for flags.
-    printf(
-        "[ERROR] unexpected positional argument (expected none)."
-        " Did you use pass a flag with a single dash ('-')?"
-        " Use '--' instead.\n");
-    return 1;
-  }
-
-  int exit_code = EXIT_SUCCESS;
-  iree_status_t status = Run(&exit_code);
-  if (!iree_status_is_ok(status)) {
-    iree_status_fprint(stderr, status);
-    iree_status_free(status);
-    return EXIT_FAILURE;
-  }
-
-  return exit_code;
-}
-
-}  // namespace iree
diff --git a/tools/test/iree-run-mlir.mlir b/tools/test/iree-run-mlir.mlir
index f899a95..47bcbe2 100644
--- a/tools/test/iree-run-mlir.mlir
+++ b/tools/test/iree-run-mlir.mlir
@@ -1,6 +1,6 @@
-// RUN: (iree-run-mlir --iree-hal-target-backends=vmvx %s --input=f32=-2) | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vulkan-spirv %s --input=f32=-2 | FileCheck %s)
-// RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu %s --input=f32=-2 | FileCheck %s
+// RUN: (iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s --input=f32=-2) | FileCheck %s
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=llvm-cpu %s --input=f32=-2 | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --Xcompiler,iree-hal-target-backends=vulkan-spirv %s --input=f32=-2 | FileCheck %s)
 
 // CHECK-LABEL: EXEC @abs
 func.func @abs(%input : tensor<f32>) -> (tensor<f32>) {
diff --git a/tools/test/multiple_args.mlir b/tools/test/multiple_args.mlir
index d00f677..e3a1cc3 100644
--- a/tools/test/multiple_args.mlir
+++ b/tools/test/multiple_args.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-compile --iree-hal-target-backends=vmvx %s | iree-run-module --module=- --function=multi_input --input="2xi32=[1 2]" --input="2xi32=[3 4]" | FileCheck %s
-// RUN: iree-run-mlir --iree-hal-target-backends=vmvx %s --input="2xi32=[1 2]" --input="2xi32=[3 4]" | FileCheck %s
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s --input="2xi32=[1 2]" --input="2xi32=[3 4]" | FileCheck %s
 // RUN: iree-compile --iree-hal-target-backends=vmvx %s | iree-benchmark-module --device=local-task --module=- --function=multi_input --input="2xi32=[1 2]" --input="2xi32=[3 4]" | FileCheck --check-prefix=BENCHMARK %s
 
 // BENCHMARK-LABEL: BM_multi_input
diff --git a/tools/test/repeated_return.mlir b/tools/test/repeated_return.mlir
index 19ba954..d22e5ca 100644
--- a/tools/test/repeated_return.mlir
+++ b/tools/test/repeated_return.mlir
@@ -1,6 +1,6 @@
 // RUN: (iree-compile --iree-hal-target-backends=vmvx %s | iree-run-module --module=- --function=many_tensor) | FileCheck %s
 // RUN: iree-compile --iree-hal-target-backends=vmvx %s | iree-benchmark-module --device=local-task --module=- --function=many_tensor | FileCheck --check-prefix=BENCHMARK %s
-// RUN: iree-run-mlir --iree-hal-target-backends=vmvx %s | FileCheck %s
+// RUN: iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s | FileCheck %s
 
 // BENCHMARK-LABEL: BM_many_tensor
 // CHECK-LABEL: EXEC @many_tensor
diff --git a/tools/test/scalars.mlir b/tools/test/scalars.mlir
index 4bdf4ec..b772c8f 100644
--- a/tools/test/scalars.mlir
+++ b/tools/test/scalars.mlir
@@ -1,6 +1,6 @@
 // RUN: (iree-compile --iree-hal-target-backends=vmvx %s | iree-run-module --module=- --function=scalar --input=42) | FileCheck %s
 // RUN: iree-compile --iree-hal-target-backends=vmvx %s | iree-benchmark-module --device=local-task --module=- --function=scalar --input=42 | FileCheck --check-prefix=BENCHMARK %s
-// RUN: (iree-run-mlir --iree-hal-target-backends=vmvx %s --input=42) | FileCheck %s
+// RUN: (iree-run-mlir --Xcompiler,iree-hal-target-backends=vmvx %s --input=42) | FileCheck %s
 
 // BENCHMARK-LABEL: BM_scalar
 // CHECK-LABEL: EXEC @scalar