blob: 775a268cc975e833a18aaf0fd8ad7c9e050bb7d9 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Tests covering the bytecode dispatch logic for async ops (yield and resume).
//
// iree/vm/test/async_ops.mlir contains the functions used here for testing. We
// avoid defining the IR inline here so that we can run this test on platforms
// that we can't run the full MLIR compiler stack on.
#include "iree/base/status_cc.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
// Compiled module embedded here to avoid file IO:
#include "iree/vm/test/async_bytecode_modules.h"
namespace iree {
namespace {
using iree::testing::status::StatusIs;
class VMBytecodeDispatchAsyncTest : public ::testing::Test {
protected:
static void SetUpTestSuite() {
IREE_CHECK_OK(iree_vm_register_builtin_types());
}
void SetUp() override {
IREE_TRACE_SCOPE();
const iree_file_toc_t* file = async_bytecode_modules_c_create();
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance_));
IREE_CHECK_OK(iree_vm_bytecode_module_create(
iree_const_byte_span_t{reinterpret_cast<const uint8_t*>(file->data),
file->size},
iree_allocator_null(), iree_allocator_system(), &bytecode_module_));
std::vector<iree_vm_module_t*> modules = {bytecode_module_};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
instance_, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(),
iree_allocator_system(), &context_));
}
void TearDown() override {
IREE_TRACE_SCOPE();
iree_vm_module_release(bytecode_module_);
iree_vm_context_release(context_);
iree_vm_instance_release(instance_);
}
iree_vm_instance_t* instance_ = nullptr;
iree_vm_context_t* context_ = nullptr;
iree_vm_module_t* bytecode_module_ = nullptr;
};
// Tests a simple straight-line yield sequence that requires 3 resumes.
// The initial call and the first two resumes are expected to report a deferred
// status; the third resume is expected to complete with `arg + 3`.
// See iree/vm/test/async_ops.mlir > @yield_sequence
TEST_F(VMBytecodeDispatchAsyncTest, YieldSequence) {
  IREE_TRACE_SCOPE();

  iree_vm_function_t function;
  IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name(
      bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      IREE_SV("yield_sequence"), &function));

  IREE_VM_INLINE_STACK_INITIALIZE(
      stack, IREE_VM_CONTEXT_FLAG_NONE, iree_vm_context_id(context_),
      iree_vm_context_state_resolver(context_), iree_allocator_system());

  // Fatal ASSERT_* failures return from the test body immediately, so the
  // stack is deinitialized by a scope guard rather than by a trailing call
  // that would be skipped on failure.
  struct StackCleanup {
    iree_vm_stack_t* stack;
    ~StackCleanup() {
      if (stack) iree_vm_stack_deinitialize(stack);
    }
  } stack_cleanup{stack};

  uint32_t arg_value = 97;
  uint32_t ret_value = 0;

  iree_vm_function_call_t call;
  memset(&call, 0, sizeof(call));
  call.function = function;
  call.arguments = iree_make_byte_span(&arg_value, sizeof(arg_value));
  call.results = iree_make_byte_span(&ret_value, sizeof(ret_value));

  iree_vm_execution_result_t result;

  // 0/3: the initial call yields before completing.
  ASSERT_THAT(
      function.module->begin_call(function.module->self, stack, &call, &result),
      StatusIs(StatusCode::kDeferred));

  // 1/3
  ASSERT_THAT(function.module->resume_call(function.module->self, stack,
                                           call.results, &result),
              StatusIs(StatusCode::kDeferred));

  // 2/3
  ASSERT_THAT(function.module->resume_call(function.module->self, stack,
                                           call.results, &result),
              StatusIs(StatusCode::kDeferred));

  // 3/3: the final resume completes and writes the result.
  IREE_ASSERT_OK(function.module->resume_call(function.module->self, stack,
                                              call.results, &result));

  ASSERT_EQ(ret_value, arg_value + 3);
}
// Tests a yield with data-dependent control, ensuring that we run the
// alternating branches and pass along branch args on resume.
// The same stack and call structure are reused for both invocations; only
// `arg0` changes between them.
// See iree/vm/test/async_ops.mlir > @yield_divergent
TEST_F(VMBytecodeDispatchAsyncTest, YieldDivergent) {
  IREE_TRACE_SCOPE();

  iree_vm_function_t function;
  IREE_ASSERT_OK(iree_vm_module_lookup_function_by_name(
      bytecode_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      IREE_SV("yield_divergent"), &function));

  IREE_VM_INLINE_STACK_INITIALIZE(
      stack, IREE_VM_CONTEXT_FLAG_NONE, iree_vm_context_id(context_),
      iree_vm_context_state_resolver(context_), iree_allocator_system());

  // Fatal ASSERT_* failures return from the test body immediately, so the
  // stack is deinitialized by a scope guard rather than by a trailing call
  // that would be skipped on failure.
  struct StackCleanup {
    iree_vm_stack_t* stack;
    ~StackCleanup() {
      if (stack) iree_vm_stack_deinitialize(stack);
    }
  } stack_cleanup{stack};

  // result = %arg0 ? %arg1 : %arg2
  struct {
    uint32_t arg0;
    uint32_t arg1;
    uint32_t arg2;
  } arg_values = {
      0,
      100,
      200,
  };
  uint32_t ret_value = 0;

  // The call references arg_values/ret_value by address, so updating
  // arg_values.arg0 below changes the arguments seen by the next begin_call.
  iree_vm_function_call_t call;
  memset(&call, 0, sizeof(call));
  call.function = function;
  call.arguments = iree_make_byte_span(&arg_values, sizeof(arg_values));
  call.results = iree_make_byte_span(&ret_value, sizeof(ret_value));

  iree_vm_execution_result_t result;

  // arg0=0: result = %arg0 ? %arg1 : %arg2 => %arg2
  arg_values.arg0 = 0;
  ASSERT_THAT(
      function.module->begin_call(function.module->self, stack, &call, &result),
      StatusIs(StatusCode::kDeferred));
  IREE_ASSERT_OK(function.module->resume_call(function.module->self, stack,
                                              call.results, &result));
  ASSERT_EQ(ret_value, arg_values.arg2);

  // arg0=1: result = %arg0 ? %arg1 : %arg2 => %arg1
  arg_values.arg0 = 1;
  ASSERT_THAT(
      function.module->begin_call(function.module->self, stack, &call, &result),
      StatusIs(StatusCode::kDeferred));
  IREE_ASSERT_OK(function.module->resume_call(function.module->self, stack,
                                              call.results, &result));
  ASSERT_EQ(ret_value, arg_values.arg1);
}
} // namespace
} // namespace iree