blob: aee13e1476d30128c12999fe7947a6b0fc2012ba [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "iree/base/api.h"
#include "iree/base/internal/file_io.h"
#include "iree/base/internal/flags.h"
#include "iree/base/status.h"
#include "iree/base/target_platform.h"
#include "iree/base/tracing.h"
#include "iree/hal/drivers/init.h"
#include "iree/modules/check/native_module.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
#include "iree/tools/utils/vm_util.h"
#include "iree/vm/bytecode_module.h"
// IREE_FORCE_BINARY_STDIN() prepares stdin for reading a binary module.
//
// On Windows stdin defaults to text mode, where line-ending translation would
// corrupt binary input, so there the macro switches stdin to binary mode via
// _setmode. On every other platform the macro expands to nothing.
#if defined(IREE_PLATFORM_WINDOWS)
#include <fcntl.h>
#include <io.h>
#define IREE_FORCE_BINARY_STDIN() _setmode(_fileno(stdin), O_BINARY)
#else
#define IREE_FORCE_BINARY_STDIN()
#endif // IREE_PLATFORM_WINDOWS
// --driver: name of the HAL driver; Run() passes it to CreateDevice().
IREE_FLAG(string, driver, "vmla", "Backend driver to use.");
// --expect_failure: inverts the process exit code in main(): a failing run
// exits with 0 and a fully passing run exits with 1.
IREE_FLAG(
bool, expect_failure, false,
"Whether running module is expected to fail. If set, failing "
"statuses from function evaluation are logged and ignored and all "
"evaluations succeeding is considered an error and will return a failure. "
"Mostly useful for testing the binary doesn't crash for failing tests.");
namespace iree {
namespace {
// gtest test object that invokes a single function of the module under test.
//
// Each test gets its own VM context, created in SetUp() from the three modules
// handed to the constructor and released again in TearDown(). The instance and
// module pointers are stored as given; this class does not retain or release
// them.
class CheckModuleTest : public ::testing::Test {
 public:
  explicit CheckModuleTest(iree_vm_instance_t* instance,
                           std::array<iree_vm_module_t*, 3> modules,
                           iree_vm_function_t function)
      : instance_(instance), modules_(modules), function_(function) {}

  // Builds the per-test context; the creation status is checked with
  // IREE_CHECK_OK.
  void SetUp() override {
    const auto host_allocator = iree_allocator_system();
    IREE_CHECK_OK(iree_vm_context_create_with_modules(
        instance_, modules_.data(), modules_.size(), host_allocator,
        &context_));
  }

  // Drops the context created by SetUp().
  void TearDown() override {
    iree_vm_context_release(context_);
  }

  // Calls the function with no policy, no inputs and no outputs, and expects
  // the invocation to report an OK status.
  void TestBody() override {
    const auto host_allocator = iree_allocator_system();
    IREE_EXPECT_OK(iree_vm_invoke(context_, function_,
                                  /*policy=*/nullptr,
                                  /*inputs=*/nullptr,
                                  /*outputs=*/nullptr, host_allocator));
  }

 private:
  // VM instance the context is created in.
  iree_vm_instance_t* instance_ = nullptr;
  // Modules registered with the context, in the order supplied by the caller.
  std::array<iree_vm_module_t*, 3> modules_;
  // Function invoked by TestBody().
  iree_vm_function_t function_;
  // Context owned by this test between SetUp() and TearDown().
  iree_vm_context_t* context_ = nullptr;
};
Status Run(std::string module_file_path, int* out_exit_code) {
IREE_TRACE_SCOPE0("iree-check-module");
*out_exit_code = 1;
IREE_RETURN_IF_ERROR(iree_hal_module_register_types(),
"registering HAL types");
iree_vm_instance_t* instance = nullptr;
IREE_RETURN_IF_ERROR(
iree_vm_instance_create(iree_allocator_system(), &instance),
"creating instance");
std::string module_data;
if (module_file_path == "-") {
module_data = std::string{std::istreambuf_iterator<char>(std::cin),
std::istreambuf_iterator<char>()};
} else {
IREE_RETURN_IF_ERROR(
file_io::GetFileContents(module_file_path.c_str(), &module_data));
}
iree_vm_module_t* input_module = nullptr;
IREE_RETURN_IF_ERROR(LoadBytecodeModule(module_data, &input_module));
iree_hal_device_t* device = nullptr;
IREE_RETURN_IF_ERROR(CreateDevice(std::string(FLAG_driver), &device));
iree_vm_module_t* hal_module = nullptr;
IREE_RETURN_IF_ERROR(CreateHalModule(device, &hal_module));
iree_vm_module_t* check_module = nullptr;
check_native_module_create(iree_allocator_system(), &check_module);
std::array<iree_vm_module_t*, 3> modules = {hal_module, check_module,
input_module};
auto module_signature = iree_vm_module_signature(input_module);
for (iree_host_size_t ordinal = 0;
ordinal < module_signature.export_function_count; ++ordinal) {
iree_vm_function_t function;
iree_string_view_t export_name_sv;
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
input_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
ordinal, &function, &export_name_sv),
"looking up function export %zu", ordinal);
// TODO(gcmn): Implicit conversion from iree to absl string view.
auto export_name =
absl::string_view(export_name_sv.data, export_name_sv.size);
iree_string_view_t module_name_iree_sv = iree_vm_module_name(input_module);
auto module_name =
absl::string_view(module_name_iree_sv.data, module_name_iree_sv.size);
if (absl::StartsWith(export_name, "__") ||
export_name.find('$') != absl::string_view::npos) {
// Skip internal or special functions.
continue;
}
IREE_RETURN_IF_ERROR(ValidateFunctionAbi(function));
std::vector<RawSignatureParser::Description> input_descs;
IREE_RETURN_IF_ERROR(ParseInputSignature(function, &input_descs));
std::vector<RawSignatureParser::Description> output_descs;
IREE_RETURN_IF_ERROR(ParseOutputSignature(function, &output_descs));
if (!input_descs.empty() || !output_descs.empty()) {
iree_string_view_t sig_f = iree_vm_function_reflection_attr(
&function, iree_make_cstring_view("f"));
RawSignatureParser sig_parser;
auto sig_str = sig_parser.FunctionSignatureToString(
absl::string_view{sig_f.data, sig_f.size});
if (!sig_str.has_value()) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"parsing function signature '%.*s': ", (int)sig_f.size, sig_f.data);
}
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"expected function with no inputs or outputs, "
"but export '%.*s' has signature '%.*s'",
(int)export_name.size(), export_name.data(),
(int)sig_f.size, sig_f.data);
}
::testing::RegisterTest(
module_name.data(), export_name.data(), nullptr,
std::to_string(ordinal).c_str(), __FILE__, __LINE__,
[&instance, modules, function]() -> CheckModuleTest* {
return new CheckModuleTest(instance, modules, function);
});
}
*out_exit_code = RUN_ALL_TESTS();
iree_vm_module_release(hal_module);
iree_vm_module_release(check_module);
iree_vm_module_release(input_module);
iree_hal_device_release(device);
iree_vm_instance_release(instance);
return OkStatus();
}
} // namespace
// Entry point: parses flags, registers HAL drivers, initializes gtest and runs
// the module named by the first positional argument ("-" for stdin).
//
// Exit code:
//   * -1 when no module path was given;
//   * without --expect_failure: the exit code produced by Run(), or 1 when
//     Run() returned a non-OK status;
//   * with --expect_failure: 0 when the run failed, 1 when it fully passed.
extern "C" int main(int argc, char** argv) {
  // Pass through flags to gtest (allowing --help to fall through).
  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_UNDEFINED_OK |
                               IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP,
                           &argc, &argv);
  IREE_CHECK_OK(iree_hal_register_all_available_drivers(
      iree_hal_driver_registry_default()));
  ::testing::InitGoogleTest(&argc, argv);
  IREE_FORCE_BINARY_STDIN();

  if (argc < 2) {
    IREE_LOG(ERROR)
        << "A binary module file path to run (or - for stdin) must be passed";
    return -1;
  }
  auto module_file_path = std::string(argv[1]);

  int exit_code = 1;
  auto status = Run(std::move(module_file_path), &exit_code);
  // A non-OK status means the tests never ran; treat that as a failure.
  int ret = status.ok() ? exit_code : 1;

  if (FLAG_expect_failure) {
    if (ret == 0) {
      std::cout << "Test passed but expected failure\n";
      std::cout << status;
      return 1;
    }
    std::cout << "Test failed as expected\n";
    // Still log a non-OK status from Run() (e.g. a module that failed to
    // load) so that this kind of failure is not hidden behind the expected
    // one.
    if (!status.ok()) {
      std::cout << status;
    }
    return 0;
  }

  if (ret != 0) {
    std::cout << "Test failed\n";
    std::cout << status;
  }
  return ret;
}
} // namespace iree