Simplifying iree-run-mlir by making it run only a single function. (#13149)
This matches the behavior of iree-run-module and makes further cleanup
easier. Testing multiple functions never really caught on and that's
probably a good thing.
Progress on #12715.
diff --git a/runtime/src/iree/tooling/context_util.c b/runtime/src/iree/tooling/context_util.c
index be23304..35985bf 100644
--- a/runtime/src/iree/tooling/context_util.c
+++ b/runtime/src/iree/tooling/context_util.c
@@ -495,6 +495,44 @@
return status;
}
+iree_status_t iree_tooling_find_single_exported_function(
+ iree_vm_module_t* module, iree_vm_function_t* out_function) {
+ memset(out_function, 0, sizeof(*out_function));
+ iree_vm_module_signature_t module_signature =
+ iree_vm_module_signature(module);
+ iree_host_size_t exported_functions = 0;
+ for (iree_host_size_t i = 0; i < module_signature.export_function_count;
+ ++i) {
+ iree_vm_function_t function = {0};
+ IREE_RETURN_IF_ERROR(
+ iree_vm_module_lookup_function_by_ordinal(
+ module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function),
+ "looking up function export %zu", i);
+ iree_string_view_t function_name = iree_vm_function_name(&function);
+ if (iree_string_view_starts_with(function_name,
+ iree_make_cstring_view("__")) ||
+ iree_string_view_find_char(function_name, '$', 0) !=
+ IREE_STRING_VIEW_NPOS) {
+ // Function was either internal or special; we don't want to run these
+ // as they have special ABI requirements or must only be called in
+ // specific situations (module initializers, etc).
+ continue;
+ }
+ if (exported_functions == 0) *out_function = function;
+ ++exported_functions;
+ }
+ if (exported_functions == 0) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "no exported functions found in module; at least one must be present");
+ } else if (exported_functions > 1) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "more than one exported function present; "
+ "--function= must be specified explicitly");
+ }
+ return iree_ok_status();
+}
+
//===----------------------------------------------------------------------===//
// Context management
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/tooling/context_util.h b/runtime/src/iree/tooling/context_util.h
index c1543c9..75a2459 100644
--- a/runtime/src/iree/tooling/context_util.h
+++ b/runtime/src/iree/tooling/context_util.h
@@ -61,16 +61,17 @@
iree_hal_device_t** out_device,
iree_hal_allocator_t** out_device_allocator);
-//===----------------------------------------------------------------------===//
-// Module loading
-//===----------------------------------------------------------------------===//
-
// Loads modules in the order specified by the --module= flag.
// Appends the modules to the |list|.
iree_status_t iree_tooling_load_modules_from_flags(
iree_vm_instance_t* instance, iree_allocator_t host_allocator,
iree_tooling_module_list_t* list);
+// Returns the single exported user function from |module| in |out_function| or
+// an error if zero or more than one function are present.
+iree_status_t iree_tooling_find_single_exported_function(
+ iree_vm_module_t* module, iree_vm_function_t* out_function);
+
//===----------------------------------------------------------------------===//
// Context management
//===----------------------------------------------------------------------===//
diff --git a/tests/e2e/regression/fill_i64.mlir b/tests/e2e/regression/fill_i64.mlir
index 12fe2fa..aaa45e4 100644
--- a/tests/e2e/regression/fill_i64.mlir
+++ b/tests/e2e/regression/fill_i64.mlir
@@ -1,5 +1,5 @@
// RUN: iree-run-mlir --iree-hal-target-backends=llvm-cpu %s --input=2x3xi64 | FileCheck %s
-// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vmvx %s --input=2x3xi64 | FileCheck %s)
+// RUN: [[ $IREE_VMVX_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vmvx %s --input=2x3xi64 | FileCheck %s)
// CHECK: EXEC @fill_i64
func.func @fill_i64(%arg0: tensor<?x?xi64>) -> (tensor<?x?xi64>, tensor<?x?xi64>) {
diff --git a/tests/e2e/regression/globals.mlir b/tests/e2e/regression/globals.mlir
index e9c8c02..0c2869c 100644
--- a/tests/e2e/regression/globals.mlir
+++ b/tests/e2e/regression/globals.mlir
@@ -1,24 +1,15 @@
-// RUN: iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vmvx %s | FileCheck %s
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-input-type=mhlo --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
+// RUN: iree-run-mlir --iree-hal-target-backends=vmvx %s | FileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir --iree-hal-target-backends=vulkan-spirv %s | FileCheck %s)
-module {
- util.global private mutable @counter = dense<2.0> : tensor<f32>
+util.global private mutable @counter = dense<2.0> : tensor<f32>
- // CHECK: EXEC @get_state
- func.func @get_state() -> tensor<f32> {
- %0 = util.global.load @counter : tensor<f32>
- return %0 : tensor<f32>
- }
- // CHECK: f32=2
-
- // CHECK: EXEC @inc
- func.func @inc() -> tensor<f32> {
- %0 = util.global.load @counter : tensor<f32>
- %c1 = arith.constant dense<1.0> : tensor<f32>
- %1 = mhlo.add %0, %c1 : tensor<f32>
- util.global.store %1, @counter : tensor<f32>
- %2 = util.global.load @counter : tensor<f32>
- return %2 : tensor<f32>
- }
- // CHECK: f32=3
+// CHECK: EXEC @inc
+func.func @inc() -> tensor<f32> {
+ %0 = util.global.load @counter : tensor<f32>
+ %c1 = arith.constant dense<1.0> : tensor<f32>
+ %1 = arith.addf %0, %c1 : tensor<f32>
+ util.global.store %1, @counter : tensor<f32>
+ %2 = util.global.load @counter : tensor<f32>
+ return %2 : tensor<f32>
}
+// CHECK: f32=3
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index ebd9762..06c3829 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -4,29 +4,32 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// IREE source.mlir -> execution output test runner.
-// This is meant to be called from LIT for FileCheck tests, and tries to match
-// the interface of mlir-opt (featuring -split-input-file, etc) so it's easier
-// to work with there. If you want a more generalized runner for standalone
-// precompiled IREE modules use iree-run-module.
+// IREE source.mlir -> execution output runner.
+// This is meant to be called from LIT for FileCheck tests or as a developer
+// tool to emulate what an online compiler does. It tries to match the interface
+// of iree-compile and iree-opt (featuring -split-input-file, etc) so it's
+// easy to run tests or approximate an iree-compile | iree-run-module sequence.
+// If you want a more generalized runner for standalone precompiled IREE modules
+// use iree-run-module instead.
//
-// By default all exported functions in the module will be run in order.
-// All input values, provided via -function-inputs, will be passed to the
-// functions (this means all input signatures must match). Results from the
-// executed functions will be printed to stdout for checking.
+// If there's a single exported function that will be executed and if there are
+// multiple functions --function= can be used to specify which is executed.
+// Function inputs can be provided with --input=. Results from the executed
+// function will be printed to stdout for checking or can be written to files
+// with --output=.
//
// Example input:
// // RUN: iree-run-mlir %s | FileCheck %s
// // CHECK-LABEL: @foo
-// // CHECK: 1xf32: 2
-// func.func @foo() -> tensor<f32> {
-// %0 = arith.constant dense<2.0> : tensor<f32>
-// return %0 : tensor<f32>
+// // CHECK: 2xf32=[2 3]
+// func.func @foo() -> tensor<2xf32> {
+// %0 = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
+// return %0 : tensor<2xf32>
// }
//
// Command line arguments are handled by LLVM's parser by default but -- can be
// used to separate the compiler flags from the runtime flags, such as:
-// iree-run-mlir --iree-hal-target-backends=vulkan-spirv -- --logtostderr
+// iree-run-mlir --iree-hal-target-backends=llvm-cpu -- --device=local-task
#include <cstdio>
#include <cstring>
@@ -139,6 +142,10 @@
llvm::cl::ConsumeAfter,
};
+IREE_FLAG(string, function, "",
+ "Name of a function contained in the compiled module. If omitted\n"
+ "and there's a single exported function that will be run instead.");
+
IREE_FLAG_LIST(
string, input,
"An input (a) value or (b) buffer of the format:\n"
@@ -436,21 +443,22 @@
return OkStatus();
}
- // Evaluate all exported functions.
- auto run_function = [&](int ordinal) -> Status {
- iree_vm_function_t function;
- IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
+ // Choose which function to run - either the one specified in the flag or the
+ // only exported non-internal function.
+ iree_vm_function_t function = {0};
+ if (strlen(FLAG_function) == 0) {
+ IREE_RETURN_IF_ERROR(iree_tooling_find_single_exported_function(
+ main_module.get(), &function));
+ } else {
+ IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name(
main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
- ordinal, &function),
- "looking up function export %d", ordinal);
+ iree_make_cstring_view(FLAG_function), &function),
+ "looking up function '%s'", FLAG_function);
+ }
+
+ // Evaluate all exported functions.
+ auto run_function = [&](iree_vm_function_t function) -> Status {
iree_string_view_t function_name = iree_vm_function_name(&function);
- if (iree_string_view_starts_with(function_name,
- iree_make_cstring_view("__")) ||
- iree_string_view_find_char(function_name, '$', 0) !=
- IREE_STRING_VIEW_NPOS) {
- // Skip internal or special functions.
- return OkStatus();
- }
// Create the context we'll use for this (ensuring that we can't interfere
// with other running evaluations, such as when in a multithreaded test
@@ -470,7 +478,8 @@
IREE_RETURN_IF_ERROR(
EvaluateFunction(context.get(), device.get(), device_allocator.get(),
function, function_name),
- "evaluating export function %d", ordinal);
+ "evaluating export function %.*s", (int)function_name.size,
+ function_name.data);
IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
@@ -479,16 +488,7 @@
device.reset();
return OkStatus();
};
-
- Status evaluate_status = OkStatus();
- auto module_signature = iree_vm_module_signature(main_module.get());
- for (iree_host_size_t i = 0; i < module_signature.export_function_count;
- ++i) {
- evaluate_status = run_function(i);
- if (!evaluate_status.ok()) {
- break;
- }
- }
+ Status evaluate_status = run_function(function);
main_module.reset();
iree_tooling_module_list_reset(&module_list);
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index 120f3ba..0ce8750 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -112,8 +112,8 @@
std::string function_name = std::string(FLAG_function);
iree_vm_function_t function;
if (function_name.empty()) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "no --function= specified");
+ IREE_RETURN_IF_ERROR(iree_tooling_find_single_exported_function(
+ iree_tooling_module_list_back(&module_list), &function));
} else {
IREE_RETURN_IF_ERROR(
iree_vm_module_lookup_function_by_name(
@@ -124,8 +124,6 @@
"looking up function '%s'", function_name.c_str());
}
- IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
-
vm::ref<iree_vm_list_t> inputs;
IREE_RETURN_IF_ERROR(iree_tooling_parse_to_variant_list(
device_allocator.get(), FLAG_input_list().values, FLAG_input_list().count,
@@ -142,6 +140,9 @@
16, host_allocator, &outputs));
printf("EXEC @%s\n", function_name.c_str());
+
+ IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
+
IREE_RETURN_IF_ERROR(
iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inputs.get(), outputs.get(),