Thread check runner and dialect into gtest.
This now produces nice structured output.
I spent quite a while fiddling around to see if I could do the setup and teardown through gtest fixtures or environments, but I couldn't find a way since I also need the input module to register the tests.
Note this includes adding slightly more to our gtest shim. I could split it out into a separate header, but it's not very big and then I'd argue we should really split gmock and gtest.
PiperOrigin-RevId: 298938652
diff --git a/iree/modules/check/BUILD b/iree/modules/check/BUILD
index caea9f7..c968a1e 100644
--- a/iree/modules/check/BUILD
+++ b/iree/modules/check/BUILD
@@ -49,6 +49,7 @@
cc_binary(
name = "iree-check-module",
+ testonly = True,
srcs = ["check_module_main.cc"],
deps = [
":native_module",
@@ -62,6 +63,7 @@
"//iree/base:source_location",
"//iree/base:status",
"//iree/modules/hal",
+ "//iree/testing:gtest",
"//iree/tools:vm_util",
"//iree/vm:bytecode_module",
] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
@@ -69,6 +71,7 @@
cc_library(
name = "native_module",
+ testonly = True,
srcs = ["native_module.cc"],
hdrs = ["native_module.h"],
deps = [
@@ -77,6 +80,7 @@
"//iree/base:buffer_string_util",
"//iree/hal:api",
"//iree/modules/hal",
+ "//iree/testing:gtest",
"//iree/vm",
"//iree/vm:module_abi_cc",
"@com_google_absl//absl/strings",
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt
index 06f30e3..bdcebf3 100644
--- a/iree/modules/check/CMakeLists.txt
+++ b/iree/modules/check/CMakeLists.txt
@@ -72,8 +72,10 @@
iree::hal::vmla::vmla_driver_module
iree::hal::vulkan::vulkan_driver_module
iree::modules::hal
+ iree::testing::gtest
iree::tools::vm_util
iree::vm::bytecode_module
+ TESTONLY
)
iree_cc_library(
@@ -90,7 +92,9 @@
iree::base::buffer_string_util
iree::hal::api
iree::modules::hal
+ iree::testing::gtest
iree::vm
iree::vm::module_abi_cc
+ TESTONLY
PUBLIC
)
diff --git a/iree/modules/check/check_module_main.cc b/iree/modules/check/check_module_main.cc
index 330ab28..88447b7 100644
--- a/iree/modules/check/check_module_main.cc
+++ b/iree/modules/check/check_module_main.cc
@@ -24,6 +24,7 @@
#include "iree/base/status.h"
#include "iree/modules/check/native_module.h"
#include "iree/modules/hal/hal_module.h"
+#include "iree/testing/gtest.h"
#include "iree/tools/vm_util.h"
#include "iree/vm/bytecode_module.h"
@@ -37,7 +38,8 @@
bool, expect_failure, false,
"Whether running module is expected to fail. If set, failing "
"statuses from function evaluation are logged and ignored and all "
- "evaluations succeeding is considered an error and will return a failure.");
+ "evaluations succeeding is considered an error and will return a failure. "
+ "Mostly useful for testing the binary doesn't crash for failing tests.");
namespace iree {
namespace {
@@ -54,7 +56,36 @@
return contents;
}
-Status Run() {
+class CheckModuleTest : public ::testing::Test {
+ public:
+ explicit CheckModuleTest(iree_vm_instance_t* instance,
+ std::array<iree_vm_module_t*, 3> modules,
+ iree_vm_function_t function)
+ : instance_(instance), modules_(modules), function_(function) {}
+ void SetUp() override {
+ IREE_ASSERT_OK(iree_vm_context_create_with_modules(
+ instance_, modules_.data(), modules_.size(), IREE_ALLOCATOR_SYSTEM,
+ &context_));
+ }
+ void TearDown() override {
+ IREE_ASSERT_OK(iree_vm_context_release(context_));
+ }
+
+ void TestBody() override {
+ IREE_EXPECT_OK(iree_vm_invoke(context_, function_, /*policy=*/nullptr,
+ /*inputs=*/nullptr, /*outputs=*/nullptr,
+ IREE_ALLOCATOR_SYSTEM));
+ }
+
+ private:
+ iree_vm_instance_t* instance_ = nullptr;
+ std::array<iree_vm_module_t*, 3> modules_;
+ iree_vm_function_t function_;
+
+ iree_vm_context_t* context_ = nullptr;
+};
+
+StatusOr<int> Run() {
RETURN_IF_ERROR(FromApiStatus(iree_hal_module_register_types(), IREE_LOC))
<< "registering HAL types";
iree_vm_instance_t* instance = nullptr;
@@ -73,18 +104,27 @@
iree_vm_module_t* check_module = nullptr;
check_native_module_create(IREE_ALLOCATOR_SYSTEM, &check_module);
- auto run_function = [&](int ordinal) -> Status {
+ std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
+ input_module};
+ auto module_signature = iree_vm_module_signature(input_module);
+ for (int ordinal = 0; ordinal < module_signature.export_function_count;
+ ++ordinal) {
iree_vm_function_t function;
RETURN_IF_ERROR(FromApiStatus(
iree_vm_module_lookup_function_by_ordinal(
input_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &function),
IREE_LOC))
<< "Looking up function export " << ordinal;
+
iree_string_view_t function_name_iree_sv = iree_vm_function_name(&function);
// TODO(gcmn): Implicit conversion from iree to absl string view.
auto function_name = absl::string_view(function_name_iree_sv.data,
function_name_iree_sv.size);
+ iree_string_view_t module_name_iree_sv = iree_vm_module_name(input_module);
+ auto module_name =
+ absl::string_view(module_name_iree_sv.data, module_name_iree_sv.size);
+
RETURN_IF_ERROR(ValidateFunctionAbi(function));
ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
@@ -103,36 +143,19 @@
<< "Expected function with no inputs or outputs, but "
<< function_name << "' has signature '" << sig_str.value() << "'";
}
- iree_vm_context_t* context = nullptr;
- // Order matters. The input module will likely be dependent on the hal and
- // check modules.
- std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
- input_module};
- RETURN_IF_ERROR(FromApiStatus(iree_vm_context_create_with_modules(
- instance, modules.data(), modules.size(),
- IREE_ALLOCATOR_SYSTEM, &context),
- IREE_LOC))
- << "creating context";
- std::cout << "EXEC @" << function_name << "\n";
- // Still release the context even if invocation failed to avoid leaks.
- auto status = Annotate(
- FromApiStatus(iree_vm_invoke(context, function, /*policy=*/nullptr,
- /*inputs=*/nullptr, /*outputs=*/nullptr,
- IREE_ALLOCATOR_SYSTEM),
- IREE_LOC),
- absl::StrCat("invoking function ", function_name));
- RETURN_IF_ERROR(FromApiStatus(iree_vm_context_release(context), IREE_LOC));
- return status;
- };
- Status evaluate_status = OkStatus();
- auto module_signature = iree_vm_module_signature(input_module);
- for (int i = 0; i < module_signature.export_function_count; ++i) {
- evaluate_status = run_function(i);
- if (!evaluate_status.ok()) {
- break;
- }
- }
+ ::testing::RegisterTest(
+ module_name.data(), function_name.data(), nullptr,
+ std::to_string(ordinal).c_str(), __FILE__, __LINE__,
+ [&instance, modules, function]() -> CheckModuleTest* {
+ return new CheckModuleTest(instance, modules, function);
+ });
+ }
+ int ret = RUN_ALL_TESTS();
+
+ // TODO(b/146898896): Investigate mechanism for sharing state between tests
+ // that happens before test registration (we need the input module) and has
+ // nice setup/teardown split.
// TODO(gcmn): Some nice wrappers to make this pattern shorter with generated
// error messages.
// Deallocate:
@@ -144,27 +167,26 @@
RETURN_IF_ERROR(FromApiStatus(iree_hal_device_release(device), IREE_LOC));
RETURN_IF_ERROR(FromApiStatus(iree_vm_instance_release(instance), IREE_LOC));
- if (absl::GetFlag(FLAGS_expect_failure)) {
- if (evaluate_status.ok()) {
- return InvalidArgumentErrorBuilder(IREE_LOC)
- << "Test passed but expected failure";
- }
- std::cout << "Test failed as expected " << evaluate_status << "\n";
- return OkStatus();
- }
- return evaluate_status;
+ return ret;
}
} // namespace
extern "C" int main(int argc, char** argv) {
InitializeEnvironment(&argc, &argv);
- auto status = Run();
- if (!status.ok()) {
- LOG(ERROR) << status << "\n";
- return 1;
+
+ int ret = Run().ValueOrDie();
+
+ if (absl::GetFlag(FLAGS_expect_failure)) {
+ if (ret == 0) {
+ std::cout << "Test passed but expected failure\n";
+ return 1;
+ }
+ std::cout << "Test failed as expected\n";
+ return 0;
}
- return 0;
+
+ return ret;
}
} // namespace iree
diff --git a/iree/modules/check/check_test.cc b/iree/modules/check/check_test.cc
index cdae4d5..2ce3935 100644
--- a/iree/modules/check/check_test.cc
+++ b/iree/modules/check/check_test.cc
@@ -98,10 +98,10 @@
IREE_ASSERT_OK(
iree_vm_variant_list_alloc(1, IREE_ALLOCATOR_SYSTEM, &inputs_));
IREE_ASSERT_OK(iree_vm_variant_list_append_value(inputs_, arg));
- auto status = iree_vm_invoke(context_, LookupFunction("expectTrue"),
- /*policy=*/nullptr, inputs_, outputs_,
- IREE_ALLOCATOR_SYSTEM);
- ASSERT_NE(IREE_STATUS_OK, status);
+ EXPECT_NONFATAL_FAILURE(
+ IREE_ASSERT_OK(iree_vm_invoke(context_, LookupFunction("expectTrue"),
+ /*policy=*/nullptr, inputs_, outputs_,
+ IREE_ALLOCATOR_SYSTEM)),
+ "Expected 0 to be nonzero");
}
-
} // namespace
diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc
index 8b181e8..bc8a3f9 100644
--- a/iree/modules/check/native_module.cc
+++ b/iree/modules/check/native_module.cc
@@ -23,6 +23,7 @@
#include "iree/base/buffer_string_util.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/hal_module.h"
+#include "iree/testing/gtest.h"
#include "iree/vm/module_abi_cc.h"
//===----------------------------------------------------------------------===//
@@ -45,9 +46,7 @@
~CheckModuleState() = default;
Status ExpectTrue(int32_t val) {
- if (val == 0)
- return Status(StatusCode::kInvalidArgument,
- absl::StrCat("Expected ", val, " to be non-zero."));
+ EXPECT_TRUE(val) << "Expected " << val << " to be nonzero.";
return OkStatus();
}
diff --git a/iree/modules/check/test/BUILD b/iree/modules/check/test/BUILD
index 9e1bd9f..d132f98 100644
--- a/iree/modules/check/test/BUILD
+++ b/iree/modules/check/test/BUILD
@@ -25,5 +25,6 @@
data = [
"//iree/modules/check:iree-check-module",
"//iree/modules/check/dialect:check-translate",
+ "//iree/tools:IreeFileCheck",
],
)
diff --git a/iree/modules/check/test/CMakeLists.txt b/iree/modules/check/test/CMakeLists.txt
index 94bd449..89ac22c 100644
--- a/iree/modules/check/test/CMakeLists.txt
+++ b/iree/modules/check/test/CMakeLists.txt
@@ -21,4 +21,5 @@
DATA
iree::modules::check::dialect::check-translate
iree::modules::check::iree-check-module
+ iree::tools::IreeFileCheck
)
diff --git a/iree/modules/check/test/failure.mlir b/iree/modules/check/test/failure.mlir
index b97e29d..0d13e65 100644
--- a/iree/modules/check/test/failure.mlir
+++ b/iree/modules/check/test/failure.mlir
@@ -1,8 +1,13 @@
-// RUN: check-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-check-module --expect_failure
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (check-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver=vulkan --expect_failure)
+// RUN: check-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-check-module --expect_failure | IreeFileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (check-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver=vulkan --expect_failure | IreeFileCheck %s)
-func @check_false() attributes { iree.module.export } {
+// CHECK-LABEL: expect_failure.expect_true_of_false
+// CHECK: Expected 0 to be nonzero
+// CHECK: Test failed as expected
+module @expect_failure {
+func @expect_true_of_false() attributes { iree.module.export } {
%false = iree.unfoldable_constant 0 : i32
check.expect_true(%false) : i32
return
}
+}
diff --git a/iree/testing/internal/gtest_internal.h b/iree/testing/internal/gtest_internal.h
index 140c351..7835c69 100644
--- a/iree/testing/internal/gtest_internal.h
+++ b/iree/testing/internal/gtest_internal.h
@@ -15,7 +15,8 @@
#ifndef IREE_TESTING_INTERNAL_GTEST_INTERNAL_H_
#define IREE_TESTING_INTERNAL_GTEST_INTERNAL_H_
-#include "gmock/gmock.h" // IWYU pragma: export
-#include "gtest/gtest.h" // IWYU pragma: export
+#include "gmock/gmock.h" // IWYU pragma: export
+#include "gtest/gtest-spi.h" // IWYU pragma: export
+#include "gtest/gtest.h" // IWYU pragma: export
#endif // IREE_TESTING_INTERNAL_GTEST_INTERNAL_H_