Simplifying iree-run-mlir by making it run only a single function. (#13149)

This matches the behavior of iree-run-module and makes further cleanup
easier. Testing multiple functions never really caught on and that's
probably a good thing.

Progress on #12715.
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c
index be23304..35985bf 100644
--- a/runtime/src/iree/tooling/context_util.c
+++ b/runtime/src/iree/tooling/context_util.c
@@ -495,6 +495,44 @@
   return status;
 }
 
+iree_status_t iree_tooling_find_single_exported_function(
+    iree_vm_module_t* module, iree_vm_function_t* out_function) {
+  memset(out_function, 0, sizeof(*out_function));
+  iree_vm_module_signature_t module_signature =
+      iree_vm_module_signature(module);
+  iree_host_size_t exported_functions = 0;
+  for (iree_host_size_t i = 0; i < module_signature.export_function_count;
+       ++i) {
+    iree_vm_function_t function = {0};
+    IREE_RETURN_IF_ERROR(
+        iree_vm_module_lookup_function_by_ordinal(
+            module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function),
+        "looking up function export %zu", i);
+    iree_string_view_t function_name = iree_vm_function_name(&function);
+    if (iree_string_view_starts_with(function_name,
+                                     iree_make_cstring_view("__")) ||
+        iree_string_view_find_char(function_name, '$', 0) !=
+            IREE_STRING_VIEW_NPOS) {
+      // Function was either internal or special; we don't want to run these
+      // as they have special ABI requirements or must only be called in
+      // specific situations (module initializers, etc).
+      continue;
+    }
+    if (exported_functions == 0) *out_function = function;
+    ++exported_functions;
+  }
+  if (exported_functions == 0) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "no exported functions found in module; at least one must be present");
+  } else if (exported_functions > 1) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "more than one exported function present; "
+                            "--function= must be specified explicitly");
+  }
+  return iree_ok_status();
+}
+
 //===----------------------------------------------------------------------===//
 // Context management
 //===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/tooling/context_util.h b/runtime/src/iree/tooling/context_util.h
index c1543c9..75a2459 100644
--- a/runtime/src/iree/tooling/context_util.h
+++ b/runtime/src/iree/tooling/context_util.h
@@ -61,16 +61,17 @@
     iree_hal_device_t** out_device,
     iree_hal_allocator_t** out_device_allocator);
 
-//===----------------------------------------------------------------------===//
-// Module loading
-//===----------------------------------------------------------------------===//
-
 // Loads modules in the order specified by the --module= flag.
 // Appends the modules to the |list|.
 iree_status_t iree_tooling_load_modules_from_flags(
     iree_vm_instance_t* instance, iree_allocator_t host_allocator,
     iree_tooling_module_list_t* list);
 
+// Returns the single exported user function from |module| in |out_function| or
+// an error if zero or more than one function are present.
+iree_status_t iree_tooling_find_single_exported_function(
+    iree_vm_module_t* module, iree_vm_function_t* out_function);
+
 //===----------------------------------------------------------------------===//
 // Context management
 //===----------------------------------------------------------------------===//
diff --git a/tests/e2e/regression/fill_i64.mlir b/tests/e2e/regression/fill_i64.mlir
index 12fe2fa..aaa45e4 100644
--- a/tests/e2e/regression/fill_i64.mlir
+++ b/tests/e2e/regression/fill_i64.mlir
@@ -1,5 +1,5 @@
 // RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu %s --input=2x3xi64 | FileCheck %s
-// RUN: [[ $IREE_VMVX_DISABLE == 1 ]]    || (iree-run-mlir --iree-hal-target-backends=vmvx %s --input=2x3xi64 | FileCheck %s)
+// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vmvx %s --input=2x3xi64 | FileCheck %s)
 
 // CHECK: EXEC @fill_i64
 func.func @fill_i64(%arg0: tensor<?x?xi64>) -> (tensor<?x?xi64>, tensor<?x?xi64>) {
diff --git a/tests/e2e/regression/globals.mlir b/tests/e2e/regression/globals.mlir
index e9c8c02..0c2869c 100644
--- a/tests/e2e/regression/globals.mlir
+++ b/tests/e2e/regression/globals.mlir
@@ -1,24 +1,15 @@
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx %s | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
+// RUN: iree-run-mlir --iree-hal-target-backends=vmvx %s | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
 
-module {
-  util.global private mutable @counter = dense<2.0> : tensor<f32>
+util.global private mutable @counter = dense<2.0> : tensor<f32>
 
-  // CHECK: EXEC @get_state
-  func.func @get_state() -> tensor<f32> {
-    %0 = util.global.load @counter : tensor<f32>
-    return %0 : tensor<f32>
-  }
-  // CHECK: f32=2
-
-  // CHECK: EXEC @inc
-  func.func @inc() -> tensor<f32> {
-    %0 = util.global.load @counter : tensor<f32>
-    %c1 = arith.constant dense<1.0> : tensor<f32>
-    %1 = mhlo.add %0, %c1 : tensor<f32>
-    util.global.store %1, @counter : tensor<f32>
-    %2 = util.global.load @counter : tensor<f32>
-    return %2 : tensor<f32>
-  }
-  // CHECK: f32=3
+// CHECK: EXEC @inc
+func.func @inc() -> tensor<f32> {
+  %0 = util.global.load @counter : tensor<f32>
+  %c1 = arith.constant dense<1.0> : tensor<f32>
+  %1 = arith.addf %0, %c1 : tensor<f32>
+  util.global.store %1, @counter : tensor<f32>
+  %2 = util.global.load @counter : tensor<f32>
+  return %2 : tensor<f32>
 }
+// CHECK: f32=3
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index ebd9762..06c3829 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -4,29 +4,32 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-// IREE source.mlir -> execution output test runner.
-// This is meant to be called from LIT for FileCheck tests, and tries to match
-// the interface of mlir-opt (featuring -split-input-file, etc) so it's easier
-// to work with there. If you want a more generalized runner for standalone
-// precompiled IREE modules use iree-run-module.
+// IREE source.mlir -> execution output runner.
+// This is meant to be called from LIT for FileCheck tests or as a developer
+// tool to emulate what an online compiler does. It tries to match the interface
+// of iree-compile and iree-opt (featuring -split-input-file, etc) so it's
+// easy to run tests or approximate an iree-compile | iree-run-module sequence.
+// If you want a more generalized runner for standalone precompiled IREE modules
+// use iree-run-module instead.
 //
-// By default all exported functions in the module will be run in order.
-// All input values, provided via -function-inputs, will be passed to the
-// functions (this means all input signatures must match). Results from the
-// executed functions will be printed to stdout for checking.
+// If there's a single exported function that will be executed and if there are
+// multiple functions --function= can be used to specify which is executed.
+// Function inputs can be provided with --input=. Results from the executed
+// function will be printed to stdout for checking or can be written to files
+// with --output=.
 //
 // Example input:
 // // RUN: iree-run-mlir %s | FileCheck %s
 // // CHECK-LABEL: @foo
-// // CHECK: 1xf32: 2
-// func.func @foo() -> tensor<f32> {
-//   %0 = arith.constant dense<2.0> : tensor<f32>
-//   return %0 : tensor<f32>
+// // CHECK: 2xf32=[2 3]
+// func.func @foo() -> tensor<2xf32> {
+//   %0 = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
+//   return %0 : tensor<2xf32>
 // }
 //
 // Command line arguments are handled by LLVM's parser by default but -- can be
 // used to separate the compiler flags from the runtime flags, such as:
-//   iree-run-mlir --iree-hal-target-backends=vulkan-spirv -- --logtostderr
+//   iree-run-mlir --iree-hal-target-backends=llvm-cpu -- --device=local-task
 
 #include <cstdio>
 #include <cstring>
@@ -139,6 +142,10 @@
     llvm::cl::ConsumeAfter,
 };
 
+IREE_FLAG(string, function, "",
+          "Name of a function contained in the compiled module. If omitted\n"
+          "and there's a single exported function that will be run instead.");
+
 IREE_FLAG_LIST(
     string, input,
     "An input (a) value or (b) buffer of the format:\n"
@@ -436,21 +443,22 @@
     return OkStatus();
   }
 
-  // Evaluate all exported functions.
-  auto run_function = [&](int ordinal) -> Status {
-    iree_vm_function_t function;
-    IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
+  // Choose which function to run - either the one specified in the flag or the
+  // only exported non-internal function.
+  iree_vm_function_t function = {0};
+  if (strlen(FLAG_function) == 0) {
+    IREE_RETURN_IF_ERROR(iree_tooling_find_single_exported_function(
+        main_module.get(), &function));
+  } else {
+    IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name(
                              main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
-                             ordinal, &function),
-                         "looking up function export %d", ordinal);
+                             iree_make_cstring_view(FLAG_function), &function),
+                         "looking up function '%s'", FLAG_function);
+  }
+
+  // Evaluate all exported functions.
+  auto run_function = [&](iree_vm_function_t function) -> Status {
     iree_string_view_t function_name = iree_vm_function_name(&function);
-    if (iree_string_view_starts_with(function_name,
-                                     iree_make_cstring_view("__")) ||
-        iree_string_view_find_char(function_name, '$', 0) !=
-            IREE_STRING_VIEW_NPOS) {
-      // Skip internal or special functions.
-      return OkStatus();
-    }
 
     // Create the context we'll use for this (ensuring that we can't interfere
     // with other running evaluations, such as when in a multithreaded test
@@ -470,7 +478,8 @@
     IREE_RETURN_IF_ERROR(
         EvaluateFunction(context.get(), device.get(), device_allocator.get(),
                          function, function_name),
-        "evaluating export function %d", ordinal);
+        "evaluating export function %.*s", (int)function_name.size,
+        function_name.data);
 
     IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
 
@@ -479,16 +488,7 @@
     device.reset();
     return OkStatus();
   };
-
-  Status evaluate_status = OkStatus();
-  auto module_signature = iree_vm_module_signature(main_module.get());
-  for (iree_host_size_t i = 0; i < module_signature.export_function_count;
-       ++i) {
-    evaluate_status = run_function(i);
-    if (!evaluate_status.ok()) {
-      break;
-    }
-  }
+  Status evaluate_status = run_function(function);
 
   main_module.reset();
   iree_tooling_module_list_reset(&module_list);
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index 120f3ba..0ce8750 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -112,8 +112,8 @@
   std::string function_name = std::string(FLAG_function);
   iree_vm_function_t function;
   if (function_name.empty()) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "no --function= specified");
+    IREE_RETURN_IF_ERROR(iree_tooling_find_single_exported_function(
+        iree_tooling_module_list_back(&module_list), &function));
   } else {
     IREE_RETURN_IF_ERROR(
         iree_vm_module_lookup_function_by_name(
@@ -124,8 +124,6 @@
         "looking up function '%s'", function_name.c_str());
   }
 
-  IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
-
   vm::ref<iree_vm_list_t> inputs;
   IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
       device_allocator.get(), FLAG_input_list().values, FLAG_input_list().count,
@@ -142,6 +140,9 @@
                                            16, host_allocator, &outputs));
 
   printf("EXEC @%s\n", function_name.c_str());
+
+  IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
+
   IREE_RETURN_IF_ERROR(
       iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE,
                      /*policy=*/nullptr, inputs.get(), outputs.get(),