blob: 2db4bd8a50c1fdddf02948d8bba7d7bdba1d9942 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/hal/llvmjit/llvmjit_executable.h"
#include <iostream>
#include <memory>
#include "iree/base/tracing.h"
#include "iree/hal/buffer.h"
#include "iree/hal/executable.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
// flatcc schemas:
#include "iree/base/flatcc.h"
#include "iree/schemas/llvmir_executable_def_reader.h"
#include "iree/schemas/llvmir_executable_def_verifier.h"
// NOTE: starting to port this to C.
// Validates an LLVMIRExecutableDef flatbuffer up front so that code consuming
// it afterwards does not have to re-check its structure.
//
// Checks, in order:
//   1. the span is non-null and at least 16 bytes long;
//   2. the flatcc-generated verifier accepts the buffer as a root
//      LLVMIRExecutableDef;
//   3. every entry point string is non-empty;
//   4. the bitcode_module vector is non-empty.
//
// Returns an OK status when all checks pass and an
// IREE_STATUS_INVALID_ARGUMENT status describing the first failure otherwise.
static iree_status_t iree_hal_llvmir_executable_flatbuffer_verify(
    iree_const_byte_span_t flatbuffer_data) {
  const bool has_minimum_payload =
      flatbuffer_data.data && flatbuffer_data.data_length >= 16;
  if (!has_minimum_payload) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "flatbuffer data is not present or less than 16 bytes (%zu total)",
        flatbuffer_data.data_length);
  }

  // Structural pass generated by flatcc. Per the original author's note this
  // guarantees in-bounds pointers but says nothing about whether the contents
  // are what this executable format expects; that is checked below.
  const int flatcc_result = iree_LLVMIRExecutableDef_verify_as_root(
      flatbuffer_data.data, flatbuffer_data.data_length);
  if (flatcc_result != flatcc_verify_ok) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "flatbuffer verification failed: %s",
                            flatcc_verify_error_string(flatcc_result));
  }

  iree_LLVMIRExecutableDef_table_t def =
      iree_LLVMIRExecutableDef_as_root(flatbuffer_data.data);

  // Every exported entry point must carry a name.
  flatbuffers_string_vec_t names =
      iree_LLVMIRExecutableDef_entry_points_get(def);
  const size_t name_count = flatbuffers_string_vec_len(names);
  for (size_t ordinal = 0; ordinal != name_count; ++ordinal) {
    flatbuffers_string_t name = flatbuffers_string_vec_at(names, ordinal);
    if (flatbuffers_string_len(name) == 0) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "executable entry point %zu has no name",
                              ordinal);
    }
  }

  // The module payload itself must be present.
  flatbuffers_uint8_vec_t module_bytes =
      iree_LLVMIRExecutableDef_bitcode_module_get(def);
  if (flatbuffers_uint8_vec_len(module_bytes) == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "executable bitcode_module is missing/empty");
  }

  return iree_ok_status();
}
namespace iree {
namespace hal {
namespace llvmjit {
// static
//
// Builds an LLVMJITExecutable from the flatbuffer-wrapped LLVM IR in |spec|.
//
// Steps: verify the LLVMIRExecutableDef flatbuffer, parse its bitcode_module
// payload with the LLVM assembly parser, hand the module to a fresh LLJIT
// instance, attach a generator that resolves symbols from the current
// process, and look up every entry point listed in the flatbuffer.
//
// |allow_aliasing_data| is forwarded unchanged to the constructor.
//
// Returns the loaded executable, or an error status when verification,
// parsing, JIT setup, or an entry point lookup fails.
StatusOr<ref_ptr<LLVMJITExecutable>> LLVMJITExecutable::Load(
    ExecutableSpec spec, bool allow_aliasing_data) {
  IREE_TRACE_SCOPE0("LLVMJITExecutable::Load");

  // Verify and fetch the executable flatbuffer wrapper.
  iree_const_byte_span_t executable_data = iree_make_const_byte_span(
      spec.executable_data.data(), spec.executable_data.size());
  IREE_RETURN_IF_ERROR(
      iree_hal_llvmir_executable_flatbuffer_verify(executable_data));
  iree_LLVMIRExecutableDef_table_t executable_def =
      iree_LLVMIRExecutableDef_as_root(executable_data.data);

  // Copy the module payload into an LLVM-owned buffer and parse it.
  flatbuffers_uint8_vec_t bitcode_module_vec =
      iree_LLVMIRExecutableDef_bitcode_module_get(executable_def);
  auto mem_buffer = llvm::MemoryBuffer::getMemBufferCopy(
      llvm::StringRef(reinterpret_cast<const char*>(bitcode_module_vec),
                      flatbuffers_uint8_vec_len(bitcode_module_vec)),
      "llvm-ir");
  auto llvm_context = std::make_unique<llvm::LLVMContext>();
  llvm::SMDiagnostic sm_diagnostic;
  auto module = llvm::parseAssembly(*mem_buffer, sm_diagnostic, *llvm_context);
  if (!module) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Can't parse LLVMIR Module: " << sm_diagnostic.getMessage().str();
  }

  // Grab the data layout before the module is moved into the wrapper below.
  auto data_layout = module->getDataLayout();
  llvm::orc::ThreadSafeModule thread_safe_module(std::move(module),
                                                 std::move(llvm_context));

  // Creating the JIT can fail; report that as a status rather than aborting
  // the process via llvm::cantFail.
  auto ll_jit_or = llvm::orc::LLJITBuilder().create();
  if (!ll_jit_or) {
    return UnavailableErrorBuilder(IREE_LOC)
           << "Can't create LLJIT instance: "
           << llvm::toString(ll_jit_or.takeError());
  }
  std::unique_ptr<llvm::orc::LLJIT> ll_jit = std::move(*ll_jit_or);

  llvm::Error add_module_error =
      ll_jit->addIRModule(std::move(thread_safe_module));
  if (add_module_error) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Can't add executable module to executable LLJIT: "
           << llvm::toString(std::move(add_module_error));
  }

  // Let the JIT resolve otherwise-undefined symbols from the host process.
  auto search_generator =
      llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
          data_layout.getGlobalPrefix());
  if (!search_generator) {
    return UnavailableErrorBuilder(IREE_LOC)
           << "Can't resolve symbols in current process: "
           << llvm::toString(search_generator.takeError());
  }
  auto& main_dylib = ll_jit->getMainJITDylib();
  main_dylib.addGenerator(std::move(search_generator.get()));

  auto executable =
      make_ref<LLVMJITExecutable>(spec, std::move(ll_jit), allow_aliasing_data);

  // Look up each entry point and store its symbol at the matching ordinal.
  flatbuffers_string_vec_t entry_points =
      iree_LLVMIRExecutableDef_entry_points_get(executable_def);
  const size_t entry_point_count = flatbuffers_string_vec_len(entry_points);
  executable->symbols_.resize(entry_point_count);
  for (size_t i = 0; i < entry_point_count; ++i) {
    flatbuffers_string_t entry_point =
        flatbuffers_string_vec_at(entry_points, i);
    auto func_symbol = executable->ll_jit_->lookup(
        llvm::StringRef(entry_point, flatbuffers_string_len(entry_point)));
    if (!func_symbol) {
      return NotFoundErrorBuilder(IREE_LOC)
             << "Can't JIT compile function '" << entry_point
             << "': " << llvm::toString(func_symbol.takeError());
    }
    executable->symbols_[i] = func_symbol.get();
  }
  return executable;
}
// Stores a copy of |spec| and takes ownership of |ll_jit|.
//
// When |allow_aliasing_data| is false the executable data is duplicated into
// cloned_executable_data_ and spec_.executable_data is repointed at that
// duplicate; when it is true spec_ keeps referring to the data it was given.
LLVMJITExecutable::LLVMJITExecutable(ExecutableSpec spec,
                                     std::unique_ptr<llvm::orc::LLJIT> ll_jit,
                                     bool allow_aliasing_data)
    : spec_(spec), ll_jit_(std::move(ll_jit)) {
  if (allow_aliasing_data) return;

  // Aliasing is not permitted: keep a private clone of the data.
  cloned_executable_data_ = {spec.executable_data.begin(),
                             spec.executable_data.end()};
  spec_.executable_data = absl::MakeConstSpan(cloned_executable_data_);
}
// Defaulted destructor; no explicit teardown logic is written here.
LLVMJITExecutable::~LLVMJITExecutable() = default;
// Per-dispatch state produced by LLVMJITExecutable::PrepareDispatch and read
// back by LLVMJITExecutable::DispatchTile.
struct LLVMJITDispatchState : public HostExecutable::DispatchState {
  LLVMJITDispatchState() = default;

  // Workgroup count copied from the dispatch parameters. Value-initialized so
  // a freshly constructed state never holds indeterminate values.
  std::array<uint32_t, 3> workgroup_count = {};
  // Workgroup size copied from the dispatch parameters (value-initialized for
  // the same reason).
  std::array<uint32_t, 3> workgroup_size = {};
  // JIT symbol of the entry point to invoke.
  llvm::JITEvaluatedSymbol symbol;
  // One pointer per bound buffer, in set-then-binding order.
  llvm::SmallVector<void*, 4> args;
  // Push constant values passed to the entry point as int32_t.
  llvm::SmallVector<int32_t, 4> push_constant;
};
// Captures everything DispatchTile needs for one dispatch.
//
// Validates |params.entry_point| against the loaded symbols, then records the
// workgroup count/size, the entry point symbol, one mapped data pointer per
// binding (iterating sets, then bindings within each set), and the push
// constant values converted to int32_t.
//
// Returns the populated state, an InvalidArgument error when the entry point
// ordinal is out of range, or the error produced by a failed MapMemory call.
StatusOr<ref_ptr<HostExecutable::DispatchState>>
LLVMJITExecutable::PrepareDispatch(const DispatchParams& params) {
  IREE_TRACE_SCOPE0("LLVMJITExecutable::PrepareDispatch");
  if (params.entry_point >= symbols_.size()) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Invalid entry point ordinal " << params.entry_point;
  }

  auto dispatch_state = make_ref<LLVMJITDispatchState>();
  dispatch_state->workgroup_count = params.workgroup_count;
  dispatch_state->workgroup_size = params.workgroup_size;
  dispatch_state->symbol = symbols_[params.entry_point];

  for (const auto& bindings : params.set_bindings) {
    for (const auto& io_binding : bindings) {
      IREE_ASSIGN_OR_RETURN(auto memory,
                            io_binding.buffer->MapMemory<uint8_t>(
                                MemoryAccessBitfield::kWrite, io_binding.offset,
                                io_binding.length));
      // NOTE(review): |memory| is destroyed at the end of this iteration while
      // the raw pointer is retained in args — confirm the pointer stays valid
      // for the buffer types used with this executable.
      dispatch_state->args.push_back(memory.mutable_data());
    }
  }

  // TODO(ataei): Consider moving this casting to codegen side ?!
  // NOTE(review): push_constants is dereferenced unconditionally — confirm
  // callers always provide it.
  for (const auto value : params.push_constants->values) {
    dispatch_state->push_constant.push_back(static_cast<int32_t>(value));
  }

  // The explicit move converts the derived ref_ptr into the returned type.
  return std::move(dispatch_state);
}
// Invokes the JIT-compiled entry point recorded in |state| for the tile
// identified by |workgroup_xyz|.
//
// |state| is downcast to LLVMJITDispatchState without a runtime check. The
// entry point receives, in order: the buffer pointer array, the push
// constants, |workgroup_xyz|, the workgroup count, and the workgroup size.
// Always returns OkStatus().
Status LLVMJITExecutable::DispatchTile(DispatchState* state,
                                       std::array<uint32_t, 3> workgroup_xyz) {
  IREE_TRACE_SCOPE0("LLVMJITExecutable::DispatchTile");

  // Signature the symbol address is called through.
  using EntryPointFn =
      void (*)(void**, int32_t*, uint32_t*, uint32_t*, uint32_t*);

  auto* prepared = static_cast<LLVMJITDispatchState*>(state);
  auto entry_fn =
      reinterpret_cast<EntryPointFn>(prepared->symbol.getAddress());

  void** buffer_args = prepared->args.data();
  int32_t* push_constants = prepared->push_constant.data();
  uint32_t* tile_id = workgroup_xyz.data();
  uint32_t* tile_count = prepared->workgroup_count.data();
  uint32_t* tile_size = prepared->workgroup_size.data();

  entry_fn(buffer_args, push_constants, tile_id, tile_count, tile_size);
  return OkStatus();
}
} // namespace llvmjit
} // namespace hal
} // namespace iree