// blob: c5d1cc8cf21f41b016314c82fe0b46b947edd2db [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/hal/host/serial/serial_command_processor.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
#include "iree/hal/host/host_descriptor_set.h"
#include "iree/hal/host/host_executable_layout.h"
namespace iree {
namespace hal {
namespace host {
// Creates a processor that presents itself as a one-shot command buffer
// accepting the given |command_categories|.
SerialCommandProcessor::SerialCommandProcessor(
    CommandCategoryBitfield command_categories)
    : CommandBuffer(CommandBufferMode::kOneShot, command_categories) {}
// Defaulted: no teardown is performed beyond destroying the members.
SerialCommandProcessor::~SerialCommandProcessor() = default;
// Opens the recording scope by raising the recording flag.
// Always returns OkStatus().
Status SerialCommandProcessor::Begin() {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::Begin");

  is_recording_ = true;
  return OkStatus();
}
// Closes the recording scope by clearing the recording flag.
// Always returns OkStatus().
Status SerialCommandProcessor::End() {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::End");

  is_recording_ = false;
  return OkStatus();
}
// Accepts an execution barrier and ignores it: none of the arguments are
// inspected. Every other command in this file runs to completion inside its
// own call, so presumably there is never outstanding work to order against.
// Always returns OkStatus().
Status SerialCommandProcessor::ExecutionBarrier(
    ExecutionStageBitfield source_stage_mask,
    ExecutionStageBitfield target_stage_mask,
    absl::Span<const MemoryBarrier> memory_barriers,
    absl::Span<const BufferBarrier> buffer_barriers) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::ExecutionBarrier");
  // Intentionally does nothing.
  return OkStatus();
}
// Accepts an event-signal command and ignores it; |event| is never touched.
// Always returns OkStatus().
Status SerialCommandProcessor::SignalEvent(
    Event* event, ExecutionStageBitfield source_stage_mask) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::SignalEvent");
  // Intentionally does nothing.
  return OkStatus();
}
// Accepts an event-reset command and ignores it; |event| is never touched.
// Always returns OkStatus().
Status SerialCommandProcessor::ResetEvent(
    Event* event, ExecutionStageBitfield source_stage_mask) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::ResetEvent");
  // Intentionally does nothing.
  return OkStatus();
}
// Accepts an event-wait command and ignores it: neither the events nor the
// barriers are inspected, and the call never blocks.
// Always returns OkStatus().
Status SerialCommandProcessor::WaitEvents(
    absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
    ExecutionStageBitfield target_stage_mask,
    absl::Span<const MemoryBarrier> memory_barriers,
    absl::Span<const BufferBarrier> buffer_barriers) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::WaitEvents");
  // Intentionally does nothing.
  return OkStatus();
}
// Performs the fill right away by forwarding all arguments to
// Buffer::Fill on |target_buffer| and returning its status unchanged.
Status SerialCommandProcessor::FillBuffer(Buffer* target_buffer,
                                          device_size_t target_offset,
                                          device_size_t length,
                                          const void* pattern,
                                          size_t pattern_length) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::FillBuffer");
  return target_buffer->Fill(target_offset, length, pattern, pattern_length);
}
// Accepts a discard command and ignores it; |buffer| is never touched.
// Always returns OkStatus().
Status SerialCommandProcessor::DiscardBuffer(Buffer* buffer) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::DiscardBuffer");
  // Lazily allocated buffers are not supported, so there is nothing to
  // release here.
  return OkStatus();
}
// Performs the update right away: writes |length| bytes, read from host
// memory starting |source_offset| bytes past |source_buffer|, into
// |target_buffer| at |target_offset| via Buffer::WriteData. The status of
// that write is returned unchanged.
Status SerialCommandProcessor::UpdateBuffer(const void* source_buffer,
                                            device_size_t source_offset,
                                            Buffer* target_buffer,
                                            device_size_t target_offset,
                                            device_size_t length) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::UpdateBuffer");
  // Byte-wise pointer arithmetic is needed to apply the source offset to the
  // untyped host pointer.
  const uint8_t* source_bytes =
      static_cast<const uint8_t*>(source_buffer) + source_offset;
  return target_buffer->WriteData(target_offset, source_bytes, length);
}
// Performs the copy right away by asking |target_buffer| to copy |length|
// units from |source_buffer| at |source_offset| into itself at
// |target_offset| (Buffer::CopyData). The resulting status is returned
// unchanged.
Status SerialCommandProcessor::CopyBuffer(Buffer* source_buffer,
                                          device_size_t source_offset,
                                          Buffer* target_buffer,
                                          device_size_t target_offset,
                                          device_size_t length) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::CopyBuffer");
  return target_buffer->CopyData(target_offset, source_buffer, source_offset,
                                 length);
}
// Stores |values| into the processor's push constant storage starting at
// element index |offset|; the stored values are handed to later dispatches.
//
// |executable_layout| is not consulted. |offset| and the span length are both
// measured in 32-bit elements, not bytes.
//
// Returns InvalidArgument if [offset, offset + values.size()) does not fit in
// the push constant storage; in that case nothing is written.
Status SerialCommandProcessor::PushConstants(
    ExecutableLayout* executable_layout, size_t offset,
    absl::Span<const uint32_t> values) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::PushConstants");
  // Compare against the remaining capacity instead of computing
  // `offset + values.size()`: that sum wraps around for very large offsets
  // and would let an out-of-range write slip past the check.
  const size_t capacity = push_constants_.values.size();
  if (offset > capacity || values.size() > capacity - offset) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Push constants out of range";
  }
  for (size_t i = 0; i < values.size(); ++i) {
    push_constants_.values[offset + i] = values[i];
  }
  return OkStatus();
}
// Replaces the bindings recorded for descriptor set |set| with a copy of
// |bindings|.
//
// |executable_layout| is treated as a HostExecutableLayout; its set count
// determines how many descriptor set slots the processor tracks.
//
// Returns FailedPrecondition when the processor was not created with the
// dispatch command category, and InvalidArgument when |set| is negative or
// not below the layout's set count.
Status SerialCommandProcessor::PushDescriptorSet(
    ExecutableLayout* executable_layout, int32_t set,
    absl::Span<const DescriptorSet::Binding> bindings) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::PushDescriptorSet");
  if (!AnyBitSet(command_categories() & CommandCategory::kDispatch)) {
    return FailedPreconditionErrorBuilder(IREE_LOC)
           << "Command processor does not support dispatch operations";
  }

  // Size the slot table from the layout before validating the index.
  auto* host_layout = static_cast<HostExecutableLayout*>(executable_layout);
  descriptor_sets_.resize(host_layout->set_count());
  const bool set_in_range = set >= 0 && set < descriptor_sets_.size();
  if (!set_in_range) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Set " << set << " out of range (" << descriptor_sets_.size()
           << ")";
  }

  // Overwrite the slot element by element with the provided bindings.
  auto& slot_bindings = descriptor_sets_[set];
  slot_bindings.resize(bindings.size());
  size_t slot_index = 0;
  for (const auto& binding : bindings) {
    slot_bindings[slot_index++] = binding;
  }
  return OkStatus();
}
// Replaces the bindings recorded for descriptor set |set| with a copy of the
// bindings held by |descriptor_set|, then applies |dynamic_offsets|.
//
// |executable_layout| is treated as a HostExecutableLayout and
// |descriptor_set| as a HostDescriptorSet. When |dynamic_offsets| is
// non-empty it must contain exactly one entry per entry in the layout's
// dynamic binding map for |set|; entry i is added to the offset of the
// binding that map entry i refers to.
//
// Returns FailedPrecondition when the processor was not created with the
// dispatch command category, and InvalidArgument when |set| is negative or
// not below the layout's set count, or when the dynamic offset count does not
// match. On a count mismatch the set's bindings have already been replaced.
Status SerialCommandProcessor::BindDescriptorSet(
    ExecutableLayout* executable_layout, int32_t set,
    DescriptorSet* descriptor_set,
    absl::Span<const device_size_t> dynamic_offsets) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::BindDescriptorSet");
  if (!AnyBitSet(command_categories() & CommandCategory::kDispatch)) {
    return FailedPreconditionErrorBuilder(IREE_LOC)
           << "Command processor does not support dispatch operations";
  }

  auto* host_executable_layout =
      static_cast<HostExecutableLayout*>(executable_layout);
  descriptor_sets_.resize(host_executable_layout->set_count());
  // Valid indices are [0, size). The previous comparison was inverted
  // (`size >= set`): it rejected every in-range set and let indices past the
  // end through to the indexing below.
  if (set < 0 || static_cast<size_t>(set) >= descriptor_sets_.size()) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Set " << set << " out of range (" << descriptor_sets_.size()
           << ")";
  }

  auto* host_descriptor_set = static_cast<HostDescriptorSet*>(descriptor_set);
  auto* set_bindings = &descriptor_sets_[set];
  *set_bindings = {host_descriptor_set->bindings().begin(),
                   host_descriptor_set->bindings().end()};

  if (!dynamic_offsets.empty()) {
    auto dynamic_binding_map =
        host_executable_layout->GetDynamicBindingMap(set);
    if (dynamic_offsets.size() != dynamic_binding_map.size()) {
      return InvalidArgumentErrorBuilder(IREE_LOC)
             << "Dynamic offset count mismatch (provided "
             << dynamic_offsets.size() << " but expected "
             << dynamic_binding_map.size() << ")";
    }
    // Map entry i names the binding that receives dynamic offset i.
    for (size_t i = 0; i < dynamic_binding_map.size(); ++i) {
      (*set_bindings)[dynamic_binding_map[i]].offset += dynamic_offsets[i];
    }
  }
  return OkStatus();
}
// Dispatches |entry_point| of |executable| over the |workgroups| grid. The
// grid is already known, so this only adds a trace scope and defers to
// DispatchGrid, returning its status unchanged.
Status SerialCommandProcessor::Dispatch(Executable* executable,
                                        int32_t entry_point,
                                        std::array<uint32_t, 3> workgroups) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::Dispatch");
  return DispatchGrid(executable, entry_point, workgroups);
}
// Dispatches |entry_point| of |executable| over a grid whose dimensions are
// stored in a buffer: three consecutive uint32_t values are read from
// |workgroups_buffer| starting at |workgroups_offset| and passed to
// DispatchGrid as the workgroup count.
//
// A failed read is returned as-is and nothing is dispatched.
Status SerialCommandProcessor::DispatchIndirect(
    Executable* executable, int32_t entry_point, Buffer* workgroups_buffer,
    device_size_t workgroups_offset) {
  IREE_TRACE_SCOPE0("SerialCommandProcessor::DispatchIndirect");
  // Left uninitialized on purpose; the read below fills all three entries.
  std::array<uint32_t, 3> grid;
  IREE_RETURN_IF_ERROR(workgroups_buffer->ReadData(
      workgroups_offset, grid.data(), grid.size() * sizeof(uint32_t)));
  return DispatchGrid(executable, entry_point, grid);
}
// Runs |entry_point| of |executable| once per workgroup in the
// |workgroup_count| grid, one tile after another, before returning.
//
// The dispatch receives the processor's current push constants and a view of
// every descriptor set recorded so far. |executable| is treated as a
// HostExecutable. Tiles are visited with x varying fastest and z slowest.
//
// Returns the error from PrepareDispatch, or from the first DispatchTile call
// that fails (remaining tiles are skipped); OkStatus() otherwise.
Status SerialCommandProcessor::DispatchGrid(
    Executable* executable, int32_t entry_point,
    std::array<uint32_t, 3> workgroup_count) {
  HostExecutable::DispatchParams params;
  params.entry_point = entry_point;
  params.workgroup_count = workgroup_count;
  params.push_constants = &push_constants_;

  // Build one span per recorded descriptor set.
  // NOTE(review): |params.set_bindings| presumably only views this vector, so
  // it must outlive the dispatch below - confirm against DispatchParams.
  absl::InlinedVector<absl::Span<const DescriptorSet::Binding>, 2>
      descriptor_sets(descriptor_sets_.size());
  for (size_t i = 0; i < descriptor_sets_.size(); ++i) {
    descriptor_sets[i] = absl::MakeConstSpan(descriptor_sets_[i]);
  }
  params.set_bindings = descriptor_sets;

  // Downcast with static_cast, as the other host downcasts in this file do:
  // unlike reinterpret_cast it applies any base-to-derived pointer adjustment
  // and is rejected by the compiler if the types are unrelated.
  auto* host_executable = static_cast<HostExecutable*>(executable);
  IREE_ASSIGN_OR_RETURN(auto dispatch_state,
                        host_executable->PrepareDispatch(params));

  for (uint32_t z = 0; z < params.workgroup_count[2]; ++z) {
    for (uint32_t y = 0; y < params.workgroup_count[1]; ++y) {
      for (uint32_t x = 0; x < params.workgroup_count[0]; ++x) {
        IREE_RETURN_IF_ERROR(
            host_executable->DispatchTile(dispatch_state.get(), {x, y, z}));
      }
    }
  }
  return OkStatus();
}
} // namespace host
} // namespace hal
} // namespace iree