Thread check runner and dialect into gtest.

This now produces nice structured output.

I spent quite a while fiddling around to see if I could do the setup and teardown through gtest fixtures or environments, but I couldn't find a way since I also need the input module to register the tests.

Note this includes adding slightly more to our gtest shim. I could split it out into a separate header, but it's not very big and then I'd argue we should really split gmock and gtest.

PiperOrigin-RevId: 298938652
diff --git a/iree/modules/check/BUILD b/iree/modules/check/BUILD
index caea9f7..c968a1e 100644
--- a/iree/modules/check/BUILD
+++ b/iree/modules/check/BUILD
@@ -49,6 +49,7 @@
 
 cc_binary(
     name = "iree-check-module",
+    testonly = True,
     srcs = ["check_module_main.cc"],
     deps = [
         ":native_module",
@@ -62,6 +63,7 @@
         "//iree/base:source_location",
         "//iree/base:status",
         "//iree/modules/hal",
+        "//iree/testing:gtest",
         "//iree/tools:vm_util",
         "//iree/vm:bytecode_module",
     ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
@@ -69,6 +71,7 @@
 
 cc_library(
     name = "native_module",
+    testonly = True,
     srcs = ["native_module.cc"],
     hdrs = ["native_module.h"],
     deps = [
@@ -77,6 +80,7 @@
         "//iree/base:buffer_string_util",
         "//iree/hal:api",
         "//iree/modules/hal",
+        "//iree/testing:gtest",
         "//iree/vm",
         "//iree/vm:module_abi_cc",
         "@com_google_absl//absl/strings",
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt
index 06f30e3..bdcebf3 100644
--- a/iree/modules/check/CMakeLists.txt
+++ b/iree/modules/check/CMakeLists.txt
@@ -72,8 +72,10 @@
     iree::hal::vmla::vmla_driver_module
     iree::hal::vulkan::vulkan_driver_module
     iree::modules::hal
+    iree::testing::gtest
     iree::tools::vm_util
     iree::vm::bytecode_module
+  TESTONLY
 )
 
 iree_cc_library(
@@ -90,7 +92,9 @@
     iree::base::buffer_string_util
     iree::hal::api
     iree::modules::hal
+    iree::testing::gtest
     iree::vm
     iree::vm::module_abi_cc
+  TESTONLY
   PUBLIC
 )
diff --git a/iree/modules/check/check_module_main.cc b/iree/modules/check/check_module_main.cc
index 330ab28..88447b7 100644
--- a/iree/modules/check/check_module_main.cc
+++ b/iree/modules/check/check_module_main.cc
@@ -24,6 +24,7 @@
 #include "iree/base/status.h"
 #include "iree/modules/check/native_module.h"
 #include "iree/modules/hal/hal_module.h"
+#include "iree/testing/gtest.h"
 #include "iree/tools/vm_util.h"
 #include "iree/vm/bytecode_module.h"
 
@@ -37,7 +38,8 @@
     bool, expect_failure, false,
     "Whether running module is expected to fail. If set, failing "
     "statuses from function evaluation are logged and ignored and all "
-    "evaluations succeeding is considered an error and will return a failure.");
+    "evaluations succeeding is considered an error and will return a failure. "
+    "Mostly useful for testing the binary doesn't crash for failing tests.");
 
 namespace iree {
 namespace {
@@ -54,7 +56,36 @@
   return contents;
 }
 
-Status Run() {
+class CheckModuleTest : public ::testing::Test {
+ public:
+  explicit CheckModuleTest(iree_vm_instance_t* instance,
+                           std::array<iree_vm_module_t*, 3> modules,
+                           iree_vm_function_t function)
+      : instance_(instance), modules_(modules), function_(function) {}
+  void SetUp() override {
+    IREE_ASSERT_OK(iree_vm_context_create_with_modules(
+        instance_, modules_.data(), modules_.size(), IREE_ALLOCATOR_SYSTEM,
+        &context_));
+  }
+  void TearDown() override {
+    IREE_ASSERT_OK(iree_vm_context_release(context_));
+  }
+
+  void TestBody() override {
+    IREE_EXPECT_OK(iree_vm_invoke(context_, function_, /*policy=*/nullptr,
+                                  /*inputs=*/nullptr, /*outputs=*/nullptr,
+                                  IREE_ALLOCATOR_SYSTEM));
+  }
+
+ private:
+  iree_vm_instance_t* instance_ = nullptr;
+  std::array<iree_vm_module_t*, 3> modules_;
+  iree_vm_function_t function_;
+
+  iree_vm_context_t* context_ = nullptr;
+};
+
+StatusOr<int> Run() {
   RETURN_IF_ERROR(FromApiStatus(iree_hal_module_register_types(), IREE_LOC))
       << "registering HAL types";
   iree_vm_instance_t* instance = nullptr;
@@ -73,18 +104,27 @@
   iree_vm_module_t* check_module = nullptr;
   check_native_module_create(IREE_ALLOCATOR_SYSTEM, &check_module);
 
-  auto run_function = [&](int ordinal) -> Status {
+  std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
+                                              input_module};
+  auto module_signature = iree_vm_module_signature(input_module);
+  for (int ordinal = 0; ordinal < module_signature.export_function_count;
+       ++ordinal) {
     iree_vm_function_t function;
     RETURN_IF_ERROR(FromApiStatus(
         iree_vm_module_lookup_function_by_ordinal(
             input_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &function),
         IREE_LOC))
         << "Looking up function export " << ordinal;
+
     iree_string_view_t function_name_iree_sv = iree_vm_function_name(&function);
     // TODO(gcmn): Implicit conversion from iree to absl string view.
     auto function_name = absl::string_view(function_name_iree_sv.data,
                                            function_name_iree_sv.size);
 
+    iree_string_view_t module_name_iree_sv = iree_vm_module_name(input_module);
+    auto module_name =
+        absl::string_view(module_name_iree_sv.data, module_name_iree_sv.size);
+
     RETURN_IF_ERROR(ValidateFunctionAbi(function));
     ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
     ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
@@ -103,36 +143,19 @@
              << "Expected function with no inputs or outputs, but "
              << function_name << "' has signature '" << sig_str.value() << "'";
     }
-    iree_vm_context_t* context = nullptr;
-    // Order matters. The input module will likely be dependent on the hal and
-    // check modules.
-    std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
-                                                input_module};
-    RETURN_IF_ERROR(FromApiStatus(iree_vm_context_create_with_modules(
-                                      instance, modules.data(), modules.size(),
-                                      IREE_ALLOCATOR_SYSTEM, &context),
-                                  IREE_LOC))
-        << "creating context";
-    std::cout << "EXEC @" << function_name << "\n";
-    // Still release the context even if invocation failed to avoid leaks.
-    auto status = Annotate(
-        FromApiStatus(iree_vm_invoke(context, function, /*policy=*/nullptr,
-                                     /*inputs=*/nullptr, /*outputs=*/nullptr,
-                                     IREE_ALLOCATOR_SYSTEM),
-                      IREE_LOC),
-        absl::StrCat("invoking function ", function_name));
-    RETURN_IF_ERROR(FromApiStatus(iree_vm_context_release(context), IREE_LOC));
-    return status;
-  };
-  Status evaluate_status = OkStatus();
-  auto module_signature = iree_vm_module_signature(input_module);
-  for (int i = 0; i < module_signature.export_function_count; ++i) {
-    evaluate_status = run_function(i);
-    if (!evaluate_status.ok()) {
-      break;
-    }
-  }
 
+    ::testing::RegisterTest(
+        module_name.data(), function_name.data(), nullptr,
+        std::to_string(ordinal).c_str(), __FILE__, __LINE__,
+        [&instance, modules, function]() -> CheckModuleTest* {
+          return new CheckModuleTest(instance, modules, function);
+        });
+  }
+  int ret = RUN_ALL_TESTS();
+
+  // TODO(b/146898896): Investigate mechanism for sharing state between tests
+  // that happens before test registration (we need the input module) and has
+  // nice setup/teardown split.
   // TODO(gcmn): Some nice wrappers to make this pattern shorter with generated
   // error messages.
   // Deallocate:
@@ -144,27 +167,26 @@
   RETURN_IF_ERROR(FromApiStatus(iree_hal_device_release(device), IREE_LOC));
   RETURN_IF_ERROR(FromApiStatus(iree_vm_instance_release(instance), IREE_LOC));
 
-  if (absl::GetFlag(FLAGS_expect_failure)) {
-    if (evaluate_status.ok()) {
-      return InvalidArgumentErrorBuilder(IREE_LOC)
-             << "Test passed but expected failure";
-    }
-    std::cout << "Test failed as expected " << evaluate_status << "\n";
-    return OkStatus();
-  }
-  return evaluate_status;
+  return ret;
 }
 
 }  // namespace
 
 extern "C" int main(int argc, char** argv) {
   InitializeEnvironment(&argc, &argv);
-  auto status = Run();
-  if (!status.ok()) {
-    LOG(ERROR) << status << "\n";
-    return 1;
+
+  int ret = Run().ValueOrDie();
+
+  if (absl::GetFlag(FLAGS_expect_failure)) {
+    if (ret == 0) {
+      std::cout << "Test passed but expected failure\n";
+      return 1;
+    }
+    std::cout << "Test failed as expected\n";
+    return 0;
   }
-  return 0;
+
+  return ret;
 }
 
 }  // namespace iree
diff --git a/iree/modules/check/check_test.cc b/iree/modules/check/check_test.cc
index cdae4d5..2ce3935 100644
--- a/iree/modules/check/check_test.cc
+++ b/iree/modules/check/check_test.cc
@@ -98,10 +98,10 @@
   IREE_ASSERT_OK(
       iree_vm_variant_list_alloc(1, IREE_ALLOCATOR_SYSTEM, &inputs_));
   IREE_ASSERT_OK(iree_vm_variant_list_append_value(inputs_, arg));
-  auto status = iree_vm_invoke(context_, LookupFunction("expectTrue"),
-                               /*policy=*/nullptr, inputs_, outputs_,
-                               IREE_ALLOCATOR_SYSTEM);
-  ASSERT_NE(IREE_STATUS_OK, status);
+  EXPECT_NONFATAL_FAILURE(
+      IREE_ASSERT_OK(iree_vm_invoke(context_, LookupFunction("expectTrue"),
+                                    /*policy=*/nullptr, inputs_, outputs_,
+                                    IREE_ALLOCATOR_SYSTEM)),
+      "Expected 0 to be nonzero");
 }
-
 }  // namespace
diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc
index 8b181e8..bc8a3f9 100644
--- a/iree/modules/check/native_module.cc
+++ b/iree/modules/check/native_module.cc
@@ -23,6 +23,7 @@
 #include "iree/base/buffer_string_util.h"
 #include "iree/hal/api.h"
 #include "iree/modules/hal/hal_module.h"
+#include "iree/testing/gtest.h"
 #include "iree/vm/module_abi_cc.h"
 
 //===----------------------------------------------------------------------===//
@@ -45,9 +46,7 @@
   ~CheckModuleState() = default;
 
   Status ExpectTrue(int32_t val) {
-    if (val == 0)
-      return Status(StatusCode::kInvalidArgument,
-                    absl::StrCat("Expected ", val, " to be non-zero."));
+    EXPECT_TRUE(val) << "Expected " << val << " to be nonzero.";
     return OkStatus();
   }
 
diff --git a/iree/modules/check/test/BUILD b/iree/modules/check/test/BUILD
index 9e1bd9f..d132f98 100644
--- a/iree/modules/check/test/BUILD
+++ b/iree/modules/check/test/BUILD
@@ -25,5 +25,6 @@
     data = [
         "//iree/modules/check:iree-check-module",
         "//iree/modules/check/dialect:check-translate",
+        "//iree/tools:IreeFileCheck",
     ],
 )
diff --git a/iree/modules/check/test/CMakeLists.txt b/iree/modules/check/test/CMakeLists.txt
index 94bd449..89ac22c 100644
--- a/iree/modules/check/test/CMakeLists.txt
+++ b/iree/modules/check/test/CMakeLists.txt
@@ -21,4 +21,5 @@
   DATA
     iree::modules::check::dialect::check-translate
     iree::modules::check::iree-check-module
+    iree::tools::IreeFileCheck
 )
diff --git a/iree/modules/check/test/failure.mlir b/iree/modules/check/test/failure.mlir
index b97e29d..0d13e65 100644
--- a/iree/modules/check/test/failure.mlir
+++ b/iree/modules/check/test/failure.mlir
@@ -1,8 +1,13 @@
-// RUN: check-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-check-module --expect_failure
-// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (check-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver=vulkan --expect_failure)
+// RUN: check-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-check-module --expect_failure | IreeFileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (check-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver=vulkan --expect_failure | IreeFileCheck %s)
 
-func @check_false() attributes { iree.module.export } {
+// CHECK-LABEL: expect_failure.expect_true_of_false
+// CHECK: Expected 0 to be nonzero
+// CHECK: Test failed as expected
+module @expect_failure {
+func @expect_true_of_false() attributes { iree.module.export } {
   %false = iree.unfoldable_constant 0 : i32
   check.expect_true(%false) : i32
   return
 }
+}
diff --git a/iree/testing/internal/gtest_internal.h b/iree/testing/internal/gtest_internal.h
index 140c351..7835c69 100644
--- a/iree/testing/internal/gtest_internal.h
+++ b/iree/testing/internal/gtest_internal.h
@@ -15,7 +15,8 @@
 #ifndef IREE_TESTING_INTERNAL_GTEST_INTERNAL_H_
 #define IREE_TESTING_INTERNAL_GTEST_INTERNAL_H_
 
-#include "gmock/gmock.h"  // IWYU pragma: export
-#include "gtest/gtest.h"  // IWYU pragma: export
+#include "gmock/gmock.h"      // IWYU pragma: export
+#include "gtest/gtest-spi.h"  // IWYU pragma: export
+#include "gtest/gtest.h"      // IWYU pragma: export
 
 #endif  // IREE_TESTING_INTERNAL_GTEST_INTERNAL_H_