A test runner for IREE modules

This is still just a prototype. It's not ready to take over for all our tests as it's still missing nice per-test logging, infra to make it difficult to accidentally check in a no-op test, and something less hacky than relying on status propagation.

PiperOrigin-RevId: 296287006
diff --git a/iree/modules/check/BUILD b/iree/modules/check/BUILD
index ba91384..7d36b4c 100644
--- a/iree/modules/check/BUILD
+++ b/iree/modules/check/BUILD
@@ -13,12 +13,22 @@
 # limitations under the License.
 
 load("//iree/tools:compilation.bzl", "iree_bytecode_module")
+load("//iree:build_defs.oss.bzl", "PLATFORM_VULKAN_DEPS")
 
 package(
     default_visibility = ["//visibility:public"],
     licenses = ["notice"],  # Apache 2.0
 )
 
+# Driver modules that register themselves at link time.
+IREE_DRIVER_MODULES = [
+    "//iree/hal/interpreter:interpreter_driver_module",
+    # TODO(b/142004903): enable when Dawn HAL implementation is functional
+    # "//iree/hal/dawn:dawn_driver_module",
+    "//iree/hal/vmla:vmla_driver_module",
+    "//iree/hal/vulkan:vulkan_driver_module",
+]
+
 iree_bytecode_module(
     name = "check_test_module",
     src = "check_test.mlir",
@@ -46,6 +56,25 @@
     ],
 )
 
+cc_binary(
+    name = "iree-check-module",
+    srcs = ["check_module_main.cc"],
+    deps = [
+        ":native_module",
+        "//iree/base:api",
+        "//iree/tools:vm_util",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/strings",
+        "//iree/base:api_util",
+        "//iree/base:file_io",
+        "//iree/base:init",
+        "//iree/base:source_location",
+        "//iree/base:status",
+        "//iree/modules/hal",
+        "//iree/vm:bytecode_module",
+    ] + PLATFORM_VULKAN_DEPS + IREE_DRIVER_MODULES,
+)
+
 cc_library(
     name = "native_module",
     srcs = ["native_module.cc"],
diff --git a/iree/modules/check/CMakeLists.txt b/iree/modules/check/CMakeLists.txt
index 998a543..01f9651 100644
--- a/iree/modules/check/CMakeLists.txt
+++ b/iree/modules/check/CMakeLists.txt
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 add_subdirectory(dialect)
+add_subdirectory(test)
 
 iree_bytecode_module(
   NAME
@@ -48,6 +49,34 @@
     iree::vm::ref
 )
 
+iree_cc_binary(
+  NAME
+    iree-check-module
+  SRCS
+    "check_module_main.cc"
+  OUT
+    iree-check-module
+  DEPS
+    ::native_module
+    absl::flags
+    absl::strings
+    iree::base::api
+    iree::base::api_util
+    iree::base::file_io
+    iree::base::init
+    iree::base::source_location
+    iree::base::status
+    # TODO(marbre): Add PLATFORM_VULKAN_DEPS
+    iree::hal::interpreter::interpreter_driver_module
+    iree::hal::vmla::vmla_driver_module
+    iree::hal::vulkan::vulkan_driver_module
+    iree::modules::hal
+    iree::tools::vm_util
+    iree::vm::bytecode_module
+)
+add_executable(iree-check-module ALIAS iree_modules_check_iree-check-module)
+
+
 iree_cc_library(
   NAME
     native_module
diff --git a/iree/modules/check/check_module_main.cc b/iree/modules/check/check_module_main.cc
new file mode 100644
index 0000000..330ab28
--- /dev/null
+++ b/iree/modules/check/check_module_main.cc
@@ -0,0 +1,170 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+
+#include "absl/flags/flag.h"
+#include "absl/strings/string_view.h"
+#include "iree/base/api.h"
+#include "iree/base/api_util.h"
+#include "iree/base/file_io.h"
+#include "iree/base/init.h"
+#include "iree/base/source_location.h"
+#include "iree/base/status.h"
+#include "iree/modules/check/native_module.h"
+#include "iree/modules/hal/hal_module.h"
+#include "iree/tools/vm_util.h"
+#include "iree/vm/bytecode_module.h"
+
+ABSL_FLAG(std::string, input_file, "-",
+          "File containing the module to load that contains the entry "
+          "function. Defaults to stdin.");
+
+ABSL_FLAG(std::string, driver, "interpreter", "Backend driver to use.");
+
+ABSL_FLAG(
+    bool, expect_failure, false,
+    "Whether running module is expected to fail. If set, failing "
+    "statuses from function evaluation are logged and ignored and all "
+    "evaluations succeeding is considered an error and will return a failure.");
+
+namespace iree {
+namespace {
+
+StatusOr<std::string> GetModuleContentsFromFlags() {
+  auto input_file = absl::GetFlag(FLAGS_input_file);
+  std::string contents;
+  if (input_file == "-") {
+    contents = std::string{std::istreambuf_iterator<char>(std::cin),
+                           std::istreambuf_iterator<char>()};
+  } else {
+    ASSIGN_OR_RETURN(contents, file_io::GetFileContents(input_file));
+  }
+  return contents;
+}
+
+Status Run() {
+  RETURN_IF_ERROR(FromApiStatus(iree_hal_module_register_types(), IREE_LOC))
+      << "registering HAL types";
+  iree_vm_instance_t* instance = nullptr;
+  RETURN_IF_ERROR(FromApiStatus(
+      iree_vm_instance_create(IREE_ALLOCATOR_SYSTEM, &instance), IREE_LOC))
+      << "creating instance";
+
+  ASSIGN_OR_RETURN(auto module_data, GetModuleContentsFromFlags());
+  iree_vm_module_t* input_module = nullptr;
+  RETURN_IF_ERROR(LoadBytecodeModule(module_data, &input_module));
+
+  iree_hal_device_t* device = nullptr;
+  RETURN_IF_ERROR(CreateDevice(absl::GetFlag(FLAGS_driver), &device));
+  iree_vm_module_t* hal_module = nullptr;
+  RETURN_IF_ERROR(CreateHalModule(device, &hal_module));
+  iree_vm_module_t* check_module = nullptr;
+  check_native_module_create(IREE_ALLOCATOR_SYSTEM, &check_module);
+
+  auto run_function = [&](int ordinal) -> Status {
+    iree_vm_function_t function;
+    RETURN_IF_ERROR(FromApiStatus(
+        iree_vm_module_lookup_function_by_ordinal(
+            input_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &function),
+        IREE_LOC))
+        << "Looking up function export " << ordinal;
+    iree_string_view_t function_name_iree_sv = iree_vm_function_name(&function);
+    // TODO(gcmn): Implicit conversion from iree to absl string view.
+    auto function_name = absl::string_view(function_name_iree_sv.data,
+                                           function_name_iree_sv.size);
+
+    RETURN_IF_ERROR(ValidateFunctionAbi(function));
+    ASSIGN_OR_RETURN(auto input_descs, ParseInputSignature(function));
+    ASSIGN_OR_RETURN(auto output_descs, ParseOutputSignature(function));
+    if (!input_descs.empty() || !output_descs.empty()) {
+      iree_string_view_t sig_f = iree_vm_function_reflection_attr(
+          &function, iree_make_cstring_view("f"));
+      RawSignatureParser sig_parser;
+      auto sig_str = sig_parser.FunctionSignatureToString(
+          absl::string_view{sig_f.data, sig_f.size});
+      if (!sig_str.has_value()) {
+        return InvalidArgumentErrorBuilder(IREE_LOC)
+               << "Parsing function signature '" << sig_f.data << "': "
+               << sig_parser.GetError().value_or("<NO ERROR AND NO VALUE>");
+      }
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Expected function with no inputs or outputs, but "
+             << function_name << "' has signature '" << sig_str.value() << "'";
+    }
+    iree_vm_context_t* context = nullptr;
+    // Order matters. The input module will likely be dependent on the hal and
+    // check modules.
+    std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
+                                                input_module};
+    RETURN_IF_ERROR(FromApiStatus(iree_vm_context_create_with_modules(
+                                      instance, modules.data(), modules.size(),
+                                      IREE_ALLOCATOR_SYSTEM, &context),
+                                  IREE_LOC))
+        << "creating context";
+    std::cout << "EXEC @" << function_name << "\n";
+    // Still release the context even if invocation failed to avoid leaks.
+    auto status = Annotate(
+        FromApiStatus(iree_vm_invoke(context, function, /*policy=*/nullptr,
+                                     /*inputs=*/nullptr, /*outputs=*/nullptr,
+                                     IREE_ALLOCATOR_SYSTEM),
+                      IREE_LOC),
+        absl::StrCat("invoking function ", function_name));
+    RETURN_IF_ERROR(FromApiStatus(iree_vm_context_release(context), IREE_LOC));
+    return status;
+  };
+  Status evaluate_status = OkStatus();
+  auto module_signature = iree_vm_module_signature(input_module);
+  for (int i = 0; i < module_signature.export_function_count; ++i) {
+    evaluate_status = run_function(i);
+    if (!evaluate_status.ok()) {
+      break;
+    }
+  }
+
+  // TODO(gcmn): Some nice wrappers to make this pattern shorter with generated
+  // error messages.
+  // Deallocate:
+  RETURN_IF_ERROR(FromApiStatus(iree_vm_module_release(hal_module), IREE_LOC));
+  RETURN_IF_ERROR(
+      FromApiStatus(iree_vm_module_release(check_module), IREE_LOC));
+  RETURN_IF_ERROR(
+      FromApiStatus(iree_vm_module_release(input_module), IREE_LOC));
+  RETURN_IF_ERROR(FromApiStatus(iree_hal_device_release(device), IREE_LOC));
+  RETURN_IF_ERROR(FromApiStatus(iree_vm_instance_release(instance), IREE_LOC));
+
+  if (absl::GetFlag(FLAGS_expect_failure)) {
+    if (evaluate_status.ok()) {
+      return InvalidArgumentErrorBuilder(IREE_LOC)
+             << "Test passed but expected failure";
+    }
+    std::cout << "Test failed as expected " << evaluate_status << "\n";
+    return OkStatus();
+  }
+  return evaluate_status;
+}
+
+}  // namespace
+
+extern "C" int main(int argc, char** argv) {
+  InitializeEnvironment(&argc, &argv);
+  auto status = Run();
+  if (!status.ok()) {
+    LOG(ERROR) << status << "\n";
+    return 1;
+  }
+  return 0;
+}
+
+}  // namespace iree
diff --git a/iree/modules/check/dialect/CMakeLists.txt b/iree/modules/check/dialect/CMakeLists.txt
index 9bbbc5c..5b5253a 100644
--- a/iree/modules/check/dialect/CMakeLists.txt
+++ b/iree/modules/check/dialect/CMakeLists.txt
@@ -81,6 +81,8 @@
 iree_cc_binary(
   NAME
     check-translate
+  OUT
+    check-translate
   DEPS
     ::dialect
     iree::tools::iree_translate_library
diff --git a/iree/modules/check/test/BUILD b/iree/modules/check/test/BUILD
new file mode 100644
index 0000000..9e1bd9f
--- /dev/null
+++ b/iree/modules/check/test/BUILD
@@ -0,0 +1,29 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = glob(["*.mlir"]),
+    data = [
+        "//iree/modules/check:iree-check-module",
+        "//iree/modules/check/dialect:check-translate",
+    ],
+)
diff --git a/iree/modules/check/test/CMakeLists.txt b/iree/modules/check/test/CMakeLists.txt
new file mode 100644
index 0000000..94bd449
--- /dev/null
+++ b/iree/modules/check/test/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+file(GLOB _GLOB_X_MLIR CONFIGURE_DEPENDS *.mlir)
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "${_GLOB_X_MLIR}"
+  DATA
+    iree::modules::check::dialect::check-translate
+    iree::modules::check::iree-check-module
+)
diff --git a/iree/modules/check/test/failure.mlir b/iree/modules/check/test/failure.mlir
new file mode 100644
index 0000000..b97e29d
--- /dev/null
+++ b/iree/modules/check/test/failure.mlir
@@ -0,0 +1,8 @@
+// RUN: check-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-check-module --expect_failure
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (check-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver=vulkan --expect_failure)
+
+func @check_false() attributes { iree.module.export } {
+  %false = iree.unfoldable_constant 0 : i32
+  check.expect_true(%false) : i32
+  return
+}
diff --git a/iree/modules/check/test/simple.mlir b/iree/modules/check/test/simple.mlir
new file mode 100644
index 0000000..ab0f3eb
--- /dev/null
+++ b/iree/modules/check/test/simple.mlir
@@ -0,0 +1,19 @@
+// RUN: check-translate --iree-hal-target-backends=interpreter-bytecode -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver="interpreter"
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (check-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-check-module --driver=vulkan)
+
+func @check_true() attributes { iree.module.export } {
+  %true = iree.unfoldable_constant 1 : i32
+  check.expect_true(%true) : i32
+  return
+}
+
+func @abs() attributes { iree.module.export } {
+  %cm5 = iree.unfoldable_constant dense<-5> : tensor<i32>
+  %result = "xla_hlo.abs"(%cm5) : (tensor<i32>) -> tensor<i32>
+  %c5 = iree.unfoldable_constant dense<5> : tensor<i32>
+  %eq = "xla_hlo.compare"(%result, %c5) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  %eq_el = extract_element %eq[] : tensor<i1>
+  check.expect_true(%eq_el) : i1
+  return
+}
+
diff --git a/iree/tools/BUILD b/iree/tools/BUILD
index 1d8e807..d203e44 100644
--- a/iree/tools/BUILD
+++ b/iree/tools/BUILD
@@ -218,11 +218,11 @@
     data = ["@llvm-project//llvm:FileCheck"],
 )
 
+# TODO(b/146898896): Refactor these into more coherent packages.
 cc_library(
     name = "vm_util",
     srcs = ["vm_util.cc"],
     hdrs = ["vm_util.h"],
-    visibility = ["//visibility:private"],
     deps = [
         "//iree/base:api_util",
         "//iree/base:buffer_string_util",