A test runner for IREE modules
This is still just a prototype. It's not ready to take over for all our tests as it's still missing nice per-test logging, infra to make it difficult to accidentally check in a no-op test, and something less hacky than relying on status propagation.
PiperOrigin-RevId: 296287006
diff --git a/iree/modules/check/BUILD b/iree/modules/check/BUILD
index ba91384..7d36b4c 100644
--- a/iree/modules/check/BUILD
+++ b/iree/modules/check/BUILD
@@ -13,12 +13,22 @@
# limitations under the License.
load("//iree/tools:compilation.bzl", "iree_bytecode_module")
+load("//iree:build_defs.oss.bzl", "PLATFORM_VULKAN_DEPS")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
+# Driver modules that register themselves at link time.
+IREE_DRIVER_MODULES = [
+ "//iree/hal/interpreter:interpreter_driver_module",
+ # TODO(b/142004903): enable when Dawn HAL implementation is functional
+ # "//iree/hal/dawn:dawn_driver_module",
+ "//iree/hal/vmla:vmla_driver_module",
+ "//iree/hal/vulkan:vulkan_driver_module",
+]
+
iree_bytecode_module(
name = "check_test_module",
src = "check_test.mlir",
@@ -46,6 +56,25 @@
],
)
+cc_binary(
+ name = "iree-check-module",
+ srcs = ["check_module_main.cc"],
+ deps = [
+ ":native_module",
+ "//iree/base:api",
+ "//iree/tools:vm_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ "//iree/base:api_util",
+ "//iree/base:file_io",
+ "//iree/base:init",
+ "//iree/base:source_location",
+ "//iree/base:status",
+ "//iree/modules/hal",
+ "//iree/vm:bytecode_module",
+ ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+)
+
cc_library(
name = "native_module",
srcs = ["native_module.cc"],
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt
index 998a543..01f9651 100644
--- a/iree/modules/check/CMakeLists.txt
+++ b/iree/modules/check/CMakeLists.txt
@@ -13,6 +13,7 @@
# limitations under the License.
add_subdirectory(dialect)
+add_subdirectory(test)
iree_bytecode_module(
NAME
@@ -48,6 +49,34 @@
iree::vm::ref
)
+iree_cc_binary(
+ NAME
+ iree-check-module
+ SRCS
+ "check_module_main.cc"
+ OUT
+ iree-check-module
+ DEPS
+ ::native_module
+ absl::flags
+ absl::strings
+ iree::base::api
+ iree::base::api_util
+ iree::base::file_io
+ iree::base::init
+ iree::base::source_location
+ iree::base::status
+ # TODO(marbre): Add PLATFORM_VULKAN_DEPS
+ iree::hal::interpreter::interpreter_driver_module
+ iree::hal::vmla::vmla_driver_module
+ iree::hal::vulkan::vulkan_driver_module
+ iree::modules::hal
+ iree::tools::vm_util
+ iree::vm::bytecode_module
+)
+add_executable(iree-check-module ALIAS iree_modules_check_iree-check-module)
+
+
iree_cc_library(
NAME
native_module
diff --git a/iree/modules/check/check_module_main.cc b/iree/modules/check/check_module_main.cc
new file mode 100644
index 0000000..330ab28
--- /dev/null
+++ b/iree/modules/check/check_module_main.cc
@@ -0,0 +1,170 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+
+#include "absl/flags/flag.h"
+#include "absl/strings/string_view.h"
+#include "iree/base/api.h"
+#include "iree/base/api_util.h"
+#include "iree/base/file_io.h"
+#include "iree/base/init.h"
+#include "iree/base/source_location.h"
+#include "iree/base/status.h"
+#include "iree/modules/check/native_module.h"
+#include "iree/modules/hal/hal_module.h"
+#include "iree/tools/vm_util.h"
+#include "iree/vm/bytecode_module.h"
+
+ABSL_FLAG(std::string, input_file, "-",
+ "File containing the module to load that contains the entry "
+ "function. Defaults to stdin.");
+
+ABSL_FLAG(std::string, driver, "interpreter", "Backend driver to use.");
+
+ABSL_FLAG(
+ bool, expect_failure, false,
+ "Whether running module is expected to fail. If set, failing "
+ "statuses from function evaluation are logged and ignored and all "
+ "evaluations succeeding is considered an error and will return a failure.");
+
+namespace iree {
+namespace {
+
+StatusOr<std::string> GetModuleContentsFromFlags() {
+ auto input_file = absl::GetFlag(FLAGS_input_file);
+ std::string contents;
+ if (input_file == "-") {
+ contents = std::string{std::istreambuf_iterator<char>(std::cin),
+ std::istreambuf_iterator<char>()};
+ } else {
+ ASSIGN_OR_RETURN(contents, file_io::GetFileContents(input_file));
+ }
+ return contents;
+}
+
+Status Run() {
+ RETURN_IF_ERROR(FromApiStatus(iree_hal_module_register_types(), IREE_LOC))
+ << "registering HAL types";
+ iree_vm_instance_t* instance = nullptr;
+ RETURN_IF_ERROR(FromApiStatus(
+ iree_vm_instance_create(IREE_ALLOCATOR_SYSTEM, &instance), IREE_LOC))
+ << "creating instance";
+
+ ASSIGN_OR_RETURN(auto module_data, GetModuleContentsFromFlags());
+ iree_vm_module_t* input_module = nullptr;
+ RETURN_IF_ERROR(LoadBytecodeModule(module_data, &input_module));
+
+ iree_hal_device_t* device = nullptr;
+ RETURN_IF_ERROR(CreateDevice(absl::GetFlag(FLAGS_driver), &device));
+ iree_vm_module_t* hal_module = nullptr;
+ RETURN_IF_ERROR(CreateHalModule(device, &hal_module));
+ iree_vm_module_t* check_module = nullptr;
+ check_native_module_create(IREE_ALLOCATOR_SYSTEM, &check_module);
+
+ auto run_function = [&](int ordinal) -> Status {
+ iree_vm_function_t function;
+ RETURN_IF_ERROR(FromApiStatus(
+ iree_vm_module_lookup_function_by_ordinal(
+ input_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &function),
+ IREE_LOC))
+ << "Looking up function export " << ordinal;
+ iree_string_view_t function_name_iree_sv = iree_vm_function_name(&function);
+ // TODO(gcmn): Implicit conversion from iree to absl string view.
+ auto function_name = absl::string_view(function_name_iree_sv.data,
+ function_name_iree_sv.size);
+
+ RETURN_IF_ERROR(ValidateFunctionAbi(function));
+ ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
+ ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
+ if (!input_descs.empty() || !output_descs.empty()) {
+ iree_string_view_t sig_f = iree_vm_function_reflection_attr(
+ &function, iree_make_cstring_view("f"));
+ RawSignatureParser sig_parser;
+ auto sig_str = sig_parser.FunctionSignatureToString(
+ absl::string_view{sig_f.data, sig_f.size});
+ if (!sig_str.has_value()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Parsing function signature '" << sig_f.data << "': "
+ << sig_parser.GetError().value_or("<NO ERROR AND NO VALUE>");
+ }
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Expected function with no inputs or outputs, but "
+ << function_name << "' has signature '" << sig_str.value() << "'";
+ }
+ iree_vm_context_t* context = nullptr;
+ // Order matters. The input module will likely be dependent on the hal and
+ // check modules.
+ std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
+ input_module};
+ RETURN_IF_ERROR(FromApiStatus(iree_vm_context_create_with_modules(
+ instance, modules.data(), modules.size(),
+ IREE_ALLOCATOR_SYSTEM, &context),
+ IREE_LOC))
+ << "creating context";
+ std::cout << "EXEC @" << function_name << "\n";
+ // Still release the context even if invocation failed to avoid leaks.
+ auto status = Annotate(
+ FromApiStatus(iree_vm_invoke(context, function, /*policy=*/nullptr,
+ /*inputs=*/nullptr, /*outputs=*/nullptr,
+ IREE_ALLOCATOR_SYSTEM),
+ IREE_LOC),
+ absl::StrCat("invoking function ", function_name));
+ RETURN_IF_ERROR(FromApiStatus(iree_vm_context_release(context), IREE_LOC));
+ return status;
+ };
+ Status evaluate_status = OkStatus();
+ auto module_signature = iree_vm_module_signature(input_module);
+ for (int i = 0; i < module_signature.export_function_count; ++i) {
+ evaluate_status = run_function(i);
+ if (!evaluate_status.ok()) {
+ break;
+ }
+ }
+
+ // TODO(gcmn): Some nice wrappers to make this pattern shorter with generated
+ // error messages.
+ // Deallocate:
+ RETURN_IF_ERROR(FromApiStatus(iree_vm_module_release(hal_module), IREE_LOC));
+ RETURN_IF_ERROR(
+ FromApiStatus(iree_vm_module_release(check_module), IREE_LOC));
+ RETURN_IF_ERROR(
+ FromApiStatus(iree_vm_module_release(input_module), IREE_LOC));
+ RETURN_IF_ERROR(FromApiStatus(iree_hal_device_release(device), IREE_LOC));
+ RETURN_IF_ERROR(FromApiStatus(iree_vm_instance_release(instance), IREE_LOC));
+
+ if (absl::GetFlag(FLAGS_expect_failure)) {
+ if (evaluate_status.ok()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Test passed but expected failure";
+ }
+ std::cout << "Test failed as expected " << evaluate_status << "\n";
+ return OkStatus();
+ }
+ return evaluate_status;
+}
+
+} // namespace
+
+extern "C" int main(int argc, char** argv) {
+ InitializeEnvironment(&argc, &argv);
+ auto status = Run();
+ if (!status.ok()) {
+ LOG(ERROR) << status << "\n";
+ return 1;
+ }
+ return 0;
+}
+
+} // namespace iree
diff --git a/iree/modules/check/dialect/CMakeLists.txt b/iree/modules/check/dialect/CMakeLists.txt
index 9bbbc5c..5b5253a 100644
--- a/iree/modules/check/dialect/CMakeLists.txt
+++ b/iree/modules/check/dialect/CMakeLists.txt
@@ -81,6 +81,8 @@
iree_cc_binary(
NAME
check-translate
+ OUT
+ check-translate
DEPS
::dialect
iree::tools::iree_translate_library
diff --git a/iree/modules/check/test/BUILD b/iree/modules/check/test/BUILD
new file mode 100644
index 0000000..9e1bd9f
--- /dev/null
+++ b/iree/modules/check/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = glob(["*.mlir"]),
+ data = [
+ "//iree/modules/check:iree-check-module",
+ "//iree/modules/check/dialect:check-translate",
+ ],
+)
diff --git a/iree/modules/check/test/CMakeLists.txt b/iree/modules/check/test/CMakeLists.txt
new file mode 100644
index 0000000..94bd449
--- /dev/null
+++ b/iree/modules/check/test/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+file(GLOB _GLOB_X_MLIR CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "${_GLOB_X_MLIR}"
+ DATA
+ iree::modules::check::dialect::check-translate
+ iree::modules::check::iree-check-module
+)
diff --git a/iree/modules/check/test/failure.mlir b/iree/modules/check/test/failure.mlir
new file mode 100644
index 0000000..b97e29d
--- /dev/null
+++ b/iree/modules/check/test/failure.mlir
@@ -0,0 +1,8 @@
+// RUN: check-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-check-module --expect_failure
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (check-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver=vulkan --expect_failure)
+
+func @check_false() attributes { iree.module.export } {
+ %false = iree.unfoldable_constant 0 : i32
+ check.expect_true(%false) : i32
+ return
+}
diff --git a/iree/modules/check/test/simple.mlir b/iree/modules/check/test/simple.mlir
new file mode 100644
index 0000000..ab0f3eb
--- /dev/null
+++ b/iree/modules/check/test/simple.mlir
@@ -0,0 +1,19 @@
+// RUN: check-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver="interpreter"
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (check-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver=vulkan)
+
+func @check_true() attributes { iree.module.export } {
+ %true = iree.unfoldable_constant 1 : i32
+ check.expect_true(%true) : i32
+ return
+}
+
+func @abs() attributes { iree.module.export } {
+ %cm5 = iree.unfoldable_constant dense<-5> : tensor<i32>
+ %result = "xla_hlo.abs"(%cm5) : (tensor<i32>) -> tensor<i32>
+ %c5 = iree.unfoldable_constant dense<5> : tensor<i32>
+ %eq = "xla_hlo.compare"(%result, %c5) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %eq_el = extract_element %eq[] : tensor<i1>
+ check.expect_true(%eq_el) : i1
+ return
+}
+
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 1d8e807..d203e44 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -218,11 +218,11 @@
data = ["@llvm-project//llvm:FileCheck"],
)
+# TODO(b/146898896): Refactor these into more coherent packages.
cc_library(
name = "vm_util",
srcs = ["vm_util.cc"],
hdrs = ["vm_util.h"],
- visibility = ["//visibility:private"],
deps = [
"//iree/base:api_util",
"//iree/base:buffer_string_util",