blob: f3d79cfbb12bc576daa3bcad294a314c4b0bb44f [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/hal/dylib/dylib_executable.h"
#include "flatbuffers/flatbuffers.h"
#include "iree/base/file_io.h"
#include "iree/base/file_path.h"
#include "iree/schemas/dylib_executable_def_generated.h"
namespace iree {
namespace hal {
namespace dylib {
// static
StatusOr<ref_ptr<DyLibExecutable>> DyLibExecutable::Load(ExecutableSpec spec) {
  // Two-phase construction: the (defaulted) constructor cannot fail, so all
  // fallible setup lives in Initialize and its status is forwarded to the
  // caller. On failure the partially-initialized executable is dropped here.
  auto loaded = make_ref<DyLibExecutable>();
  IREE_RETURN_IF_ERROR(loaded->Initialize(spec));
  return loaded;
}
DyLibExecutable::DyLibExecutable() = default;

// Releases the loaded library first and only then removes every temp file
// recorded during Initialize. The ordering presumably keeps the files from
// still being in use when deleted (e.g. a loaded DLL cannot be removed on
// Windows) — NOTE(review): confirm.
DyLibExecutable::~DyLibExecutable() {
  IREE_TRACE_SCOPE0("DyLibExecutable::dtor");

  // TODO(benvanik): move to an atexit handler when tracing is enabled.
  // executable_library_.release();
  executable_library_.reset();

  // Best-effort cleanup: a failed delete must not abort destruction, so each
  // error is explicitly discarded.
  for (const auto& temp_path : temp_file_paths_) {
    file_io::DeleteFile(temp_path).IgnoreError();
  }
}
// Parses the DyLibExecutableDef flatbuffer in |spec|, writes its embedded
// library (and optional debug database) to temp files, loads the library and
// resolves one symbol per declared entry point.
//
// Returns InvalidArgument if the flatbuffer is malformed or lacks entry
// points / an embedded library, NotFound if an entry point symbol is missing
// from the loaded library, or any error from the file and loader APIs.
// Every temp file created here is recorded in temp_file_paths_ so the
// destructor can remove it.
Status DyLibExecutable::Initialize(ExecutableSpec spec) {
  IREE_TRACE_SCOPE0("DyLibExecutable::Initialize");

  // Verify the flatbuffer before reading any field so truncated or corrupt
  // executable data yields an error instead of out-of-bounds reads.
  // NOTE(review): assumes executable_data is a span of bytes (size() counts
  // bytes) — confirm against ExecutableSpec.
  ::flatbuffers::Verifier verifier(
      reinterpret_cast<const uint8_t*>(spec.executable_data.data()),
      spec.executable_data.size());
  if (!verifier.VerifyBuffer<DyLibExecutableDef>(nullptr)) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "DyLibExecutableDef flatbuffer failed verification";
  }
  auto dylib_executable_def =
      ::flatbuffers::GetRoot<DyLibExecutableDef>(spec.executable_data.data());

  if (!dylib_executable_def->entry_points() ||
      dylib_executable_def->entry_points()->size() == 0) {
    return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined";
  }
  if (!dylib_executable_def->library_embedded() ||
      dylib_executable_def->library_embedded()->size() == 0) {
    return InvalidArgumentErrorBuilder(IREE_LOC) << "No embedded library";
  }

  // Write the embedded library out to a temp file, since all of the dynamic
  // library APIs work with files. We could instead use in-memory files on
  // platforms where that is convenient.
  std::string base_name = "dylib_executable";
  IREE_ASSIGN_OR_RETURN(auto library_temp_path,
                        file_io::GetTempFile(base_name));
  // GetTempFile may itself create a file at this exact path, so track it.
  temp_file_paths_.push_back(library_temp_path);

  // Add platform-specific file extensions so opinionated dynamic library
  // loaders are more likely to find the file:
#if defined(IREE_PLATFORM_WINDOWS)
  library_temp_path += ".dll";
#else
  library_temp_path += ".so";
#endif
  // The library is actually written to the suffixed path. Track it as well
  // (before writing, so partial writes are cleaned up too); otherwise the
  // destructor only deletes the unsuffixed path and this file leaks on disk.
  temp_file_paths_.push_back(library_temp_path);

  absl::string_view embedded_library_data(
      reinterpret_cast<const char*>(
          dylib_executable_def->library_embedded()->data()),
      dylib_executable_def->library_embedded()->size());
  IREE_RETURN_IF_ERROR(
      file_io::SetFileContents(library_temp_path, embedded_library_data));

  IREE_ASSIGN_OR_RETURN(executable_library_,
                        DynamicLibrary::Load(library_temp_path.c_str()));

  // Optionally drop the debug database next to the library and attach it.
  // This is best-effort: a failed write is ignored rather than failing load.
  if (dylib_executable_def->debug_database_filename() &&
      dylib_executable_def->debug_database_embedded()) {
    IREE_TRACE_SCOPE0("DyLibExecutable::AttachDebugDatabase");
    absl::string_view debug_database_filename(
        dylib_executable_def->debug_database_filename()->data(),
        dylib_executable_def->debug_database_filename()->size());
    absl::string_view debug_database_data(
        reinterpret_cast<const char*>(
            dylib_executable_def->debug_database_embedded()->data()),
        dylib_executable_def->debug_database_embedded()->size());
    auto debug_database_path = file_path::JoinPaths(
        file_path::DirectoryName(library_temp_path), debug_database_filename);
    temp_file_paths_.push_back(debug_database_path);
    IREE_IGNORE_ERROR(
        file_io::SetFileContents(debug_database_path, debug_database_data));
    executable_library_->AttachDebugDatabase(debug_database_path.c_str());
  }

  // Resolve each entry point, in declaration order, to a library symbol.
  // NOTE(review): entry_names_ stores pointers into the flatbuffer, so
  // spec.executable_data presumably must outlive this executable — confirm.
  const auto& entry_points = *dylib_executable_def->entry_points();
  entry_functions_.resize(entry_points.size());
  IREE_TRACE(entry_names_.resize(entry_points.size()));
  for (::flatbuffers::uoffset_t i = 0; i < entry_points.size(); ++i) {
    const char* entry_name = entry_points[i]->c_str();
    void* symbol = executable_library_->GetSymbol(entry_name);
    if (!symbol) {
      // Stream the name itself; streaming the flatbuffers::String pointer
      // would print an address.
      return NotFoundErrorBuilder(IREE_LOC)
             << "Could not find symbol: " << entry_name;
    }
    entry_functions_[i] = symbol;
    IREE_TRACE(entry_names_[i] = entry_name);
  }

  return OkStatus();
}
// Per-dispatch state produced by DyLibExecutable::PrepareDispatch and read by
// DyLibExecutable::DispatchTile.
struct DyLibDispatchState : public HostExecutable::DispatchState {
  DyLibDispatchState() = default;

  // Entry point name for dynamic trace scopes (the IREE_TRACE wrapper
  // presumably compiles this member out when tracing is disabled).
  IREE_TRACE(const char* entry_name = nullptr);
  // Resolved library symbol for the entry point being dispatched.
  void* entry_function = nullptr;
  // Host pointers to each bound buffer, flattened across all binding sets.
  // Value-initialized so unused slots are null instead of indeterminate.
  std::array<void*, 32> args{};
  // Push constant values copied from the dispatch params; zero by default.
  std::array<uint32_t, 32> push_constants{};
};
// Builds the per-dispatch state for |params|: selects the entry function,
// maps every bound buffer into a host pointer and copies the push constants.
//
// Returns InvalidArgument if the entry point ordinal is out of range or if
// more buffers are bound than the fixed-size args array can hold; mapping
// errors are propagated.
StatusOr<ref_ptr<HostExecutable::DispatchState>>
DyLibExecutable::PrepareDispatch(const DispatchParams& params) {
  IREE_TRACE_SCOPE0("DyLibExecutable::PrepareDispatch");

  if (params.entry_point >= entry_functions_.size()) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Invalid entry point ordinal " << params.entry_point;
  }

  auto dispatch_state = make_ref<DyLibDispatchState>();
  IREE_TRACE(dispatch_state->entry_name = entry_names_[params.entry_point]);
  dispatch_state->entry_function = entry_functions_[params.entry_point];

  // Flatten bindings set-major into args. args has a fixed capacity, so
  // reject overflow instead of writing past the end of the array.
  size_t binding_count = 0;
  for (const auto& bindings_in_set : params.set_bindings) {
    for (const auto& io_binding : bindings_in_set) {
      if (binding_count >= dispatch_state->args.size()) {
        return InvalidArgumentErrorBuilder(IREE_LOC)
               << "Too many bindings for dispatch; at most "
               << dispatch_state->args.size() << " are supported";
      }
      // NOTE(review): `memory` is destroyed at the end of this iteration,
      // which presumably unmaps it; keeping the raw pointer relies on host
      // mappings remaining valid after unmap — confirm.
      IREE_ASSIGN_OR_RETURN(auto memory,
                            io_binding.buffer->MapMemory<uint8_t>(
                                MemoryAccessBitfield::kWrite,
                                io_binding.offset, io_binding.length));
      dispatch_state->args[binding_count++] = memory.mutable_data();
    }
  }

  // Tolerate a missing push constant block: dispatch with all zeros.
  if (params.push_constants) {
    dispatch_state->push_constants = params.push_constants->values;
  } else {
    dispatch_state->push_constants.fill(0);
  }

  return std::move(dispatch_state);
}
// Invokes the prepared entry function once for the workgroup at
// |workgroup_xyz|, passing the binding pointers and push constants captured
// by PrepareDispatch. |state| must be a DyLibDispatchState.
Status DyLibExecutable::DispatchTile(DispatchState* state,
                                     std::array<uint32_t, 3> workgroup_xyz) {
  auto* dispatch_state = static_cast<DyLibDispatchState*>(state);
  IREE_TRACE_SCOPE_DYNAMIC(dispatch_state->entry_name);

  // Calling convention assumed for the library's entry points:
  // (binding pointers, push constants, workgroup x, y, z).
  using EntryFunction =
      void (*)(void**, uint32_t*, int32_t, int32_t, int32_t);
  auto entry_function =
      reinterpret_cast<EntryFunction>(dispatch_state->entry_function);

  // Workgroup ids are unsigned here but signed in the entry ABI; convert
  // explicitly rather than relying on implicit narrowing.
  entry_function(dispatch_state->args.data(),
                 dispatch_state->push_constants.data(),
                 static_cast<int32_t>(workgroup_xyz[0]),
                 static_cast<int32_t>(workgroup_xyz[1]),
                 static_cast<int32_t>(workgroup_xyz[2]));
  return OkStatus();
}
} // namespace dylib
} // namespace hal
} // namespace iree