blob: b006ae82ab08d6091a6d42f6a83345eb3a0c89b0 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/mlir_edge/iree/hal/host/host_submission_queue.h"

#include <atomic>
#include <cstddef>
#include <cstdint>

#include "third_party/absl/synchronization/mutex.h"
#include "third_party/mlir_edge/iree/base/status.h"
#include "third_party/mlir_edge/iree/base/tracing.h"
namespace iree {
namespace hal {
// Creates a binary semaphore whose state fields all start at zero (nothing
// pending), except that it begins in the signaled state when |initial_value|
// is true.
HostBinarySemaphore::HostBinarySemaphore(bool initial_value) {
  // Zero-initialize every field, then raise only the signaled bit if asked.
  State initial_state = {0};
  if (initial_value) {
    initial_state.signaled = 1;
  }
  state_.store(initial_state);
}
bool HostBinarySemaphore::is_signaled() const {
return state_.load(std::memory_order_acquire).signaled == 1;
}
// Marks a signal operation as pending on this binary semaphore.
//
// Returns FailedPrecondition if a signal is already pending. The state update
// is published with a compare-exchange retry loop. If |state_| changes
// concurrently, the precondition is re-checked against the fresh value instead
// of the update being silently dropped. (Previously the CAS result was ignored.)
Status HostBinarySemaphore::BeginSignaling() {
  State old_state = state_.load(std::memory_order_acquire);
  State new_state = old_state;
  do {
    if (old_state.signal_pending != 0) {
      return FailedPreconditionErrorBuilder(ABSL_LOC)
             << "A signal operation on a binary semaphore is already pending";
    }
    new_state = old_state;
    new_state.signal_pending = 1;
    // On failure |old_state| is reloaded with the current value and we retry.
  } while (!state_.compare_exchange_weak(old_state, new_state));
  return OkStatus();
}
// Completes a pending signal operation: clears |signal_pending| and sets
// |signaled|.
//
// A pending signal is a debug-checked precondition. Returns FailedPrecondition
// if the semaphore is already signaled. The transition is applied with a
// compare-exchange retry loop, so a concurrent state change is re-validated
// rather than lost. (Previously the CAS result was ignored.)
Status HostBinarySemaphore::EndSignaling() {
  State old_state = state_.load(std::memory_order_acquire);
  State new_state = old_state;
  do {
    DCHECK_EQ(old_state.signal_pending, 1)
        << "A signal operation on a binary semaphore was not pending";
    if (old_state.signaled != 0) {
      return FailedPreconditionErrorBuilder(ABSL_LOC)
             << "A binary semaphore cannot be signaled multiple times";
    }
    new_state = old_state;
    new_state.signal_pending = 0;
    new_state.signaled = 1;
    // On failure |old_state| is reloaded with the current value and we retry.
  } while (!state_.compare_exchange_weak(old_state, new_state));
  return OkStatus();
}
// Marks a wait operation as pending on this binary semaphore.
//
// Returns FailedPrecondition if a wait is already pending. The state update is
// published with a compare-exchange retry loop. If |state_| changes
// concurrently, the precondition is re-checked against the fresh value instead
// of the update being silently dropped. (Previously the CAS result was ignored.)
Status HostBinarySemaphore::BeginWaiting() {
  State old_state = state_.load(std::memory_order_acquire);
  State new_state = old_state;
  do {
    if (old_state.wait_pending != 0) {
      return FailedPreconditionErrorBuilder(ABSL_LOC)
             << "A wait operation on a binary semaphore is already pending";
    }
    new_state = old_state;
    new_state.wait_pending = 1;
    // On failure |old_state| is reloaded with the current value and we retry.
  } while (!state_.compare_exchange_weak(old_state, new_state));
  return OkStatus();
}
// Completes a pending wait operation: clears |wait_pending| and resets
// |signaled| back to 0.
//
// A pending wait is a debug-checked precondition. Returns FailedPrecondition
// if the semaphore is not currently signaled. The transition is applied with a
// compare-exchange retry loop, so a concurrent state change is re-validated
// rather than lost. (Previously the CAS result was ignored.)
Status HostBinarySemaphore::EndWaiting() {
  State old_state = state_.load(std::memory_order_acquire);
  State new_state = old_state;
  do {
    DCHECK_EQ(old_state.wait_pending, 1)
        << "A wait operation on a binary semaphore was not pending";
    if (old_state.signaled != 1) {
      return FailedPreconditionErrorBuilder(ABSL_LOC)
             << "A binary semaphore cannot be reset multiple times";
    }
    new_state = old_state;
    new_state.wait_pending = 0;
    new_state.signaled = 0;
    // On failure |old_state| is reloaded with the current value and we retry.
  } while (!state_.compare_exchange_weak(old_state, new_state));
  return OkStatus();
}
// Defaulted: members are initialized by their declarations in the header.
HostSubmissionQueue::HostSubmissionQueue() = default;
// Defaulted: this destructor does not itself signal or fail the fences of any
// submissions still queued. NOTE(review): presumably owners drain the queue
// (e.g. via ProcessBatches / FailAllPending) before destruction — confirm.
HostSubmissionQueue::~HostSubmissionQueue() = default;
// Returns true when every wait semaphore of |batch| is currently signaled, so
// the batch may execute now. A batch with no wait semaphores is always ready.
// Timeline semaphores are not implemented yet, so a batch waiting on one is
// never reported as ready.
bool HostSubmissionQueue::IsBatchReady(const PendingBatch& batch) const {
  for (const auto& wait_point : batch.wait_semaphores) {
    if (wait_point.index() != 0) {
      // TODO(b/140141417): implement timeline semaphores.
      return false;
    }
    // static_cast (not reinterpret_cast) applies any base->derived pointer
    // adjustment, matching the HostFence downcast in CompleteSubmission.
    // Assumes every binary semaphore given to this queue is a
    // HostBinarySemaphore.
    auto* binary_semaphore =
        static_cast<HostBinarySemaphore*>(absl::get<0>(wait_point));
    if (!binary_semaphore->is_signaled()) {
      return false;
    }
  }
  return true;
}
// Validates |batches| and records them as a single submission whose |fence| is
// signaled once all of its batches have executed, or failed if one fails.
//
// Nothing executes here; ProcessBatches runs each batch once its wait
// semaphores are signaled. To model Vulkan semantics consistently across HAL
// implementations, each wait semaphore is marked with BeginWaiting and each
// signal semaphore with BeginSignaling, so double waits/signals are rejected
// up front.
//
// Returns FailedPrecondition after SignalShutdown, the sticky error if an
// earlier batch failed, Unimplemented for timeline semaphores, or any error
// from the semaphore Begin* calls.
Status HostSubmissionQueue::Enqueue(absl::Span<const SubmissionBatch> batches,
                                    FenceValue fence) {
  IREE_TRACE_SCOPE0("HostSubmissionQueue::Enqueue");
  if (has_shutdown_) {
    return FailedPreconditionErrorBuilder(ABSL_LOC)
           << "Cannot enqueue new submissions; queue is exiting";
  } else if (!permanent_error_.ok()) {
    return permanent_error_;
  }

  // Verify waiting/signaling behavior on semaphores and prepare them all.
  // NOTE(review): if validation fails partway through, semaphores already
  // marked pending by earlier iterations are not rolled back.
  for (const auto& batch : batches) {
    for (const auto& semaphore_value : batch.wait_semaphores) {
      if (semaphore_value.index() != 0) {
        // TODO(b/140141417): implement timeline semaphores.
        return UnimplementedErrorBuilder(ABSL_LOC) << "Timeline semaphores NYI";
      }
      // static_cast applies any base->derived pointer adjustment, which
      // reinterpret_cast would skip.
      auto* binary_semaphore =
          static_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value));
      RETURN_IF_ERROR(binary_semaphore->BeginWaiting());
    }
    for (const auto& semaphore_value : batch.signal_semaphores) {
      if (semaphore_value.index() != 0) {
        // TODO(b/140141417): implement timeline semaphores.
        return UnimplementedErrorBuilder(ABSL_LOC) << "Timeline semaphores NYI";
      }
      auto* binary_semaphore =
          static_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value));
      RETURN_IF_ERROR(binary_semaphore->BeginSignaling());
    }
  }

  // Add to list - order does not matter as Process evaluates semaphores.
  auto submission = absl::make_unique<Submission>();
  submission->fence = std::move(fence);
  submission->pending_batches.resize(batches.size());
  // size_t index avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < batches.size(); ++i) {
    const auto& source = batches[i];
    submission->pending_batches[i] = PendingBatch{
        {source.wait_semaphores.begin(), source.wait_semaphores.end()},
        {source.command_buffers.begin(), source.command_buffers.end()},
        {source.signal_semaphores.begin(), source.signal_semaphores.end()},
    };
  }
  list_.push_back(std::move(submission));
  return OkStatus();
}
// Executes every ready batch, repeating until the queue is drained, every
// remaining batch is blocked, or a batch fails.
//
// A batch is ready when all of its wait semaphores are signaled (IsBatchReady).
// ProcessBatch runs it, passing its command buffers to |execute_fn|, and the
// batch is then removed from its submission. A submission with no remaining
// batches is detached from the list and its fence is signaled.
//
// The first batch failure is stored in |permanent_error_| and fails that
// submission's fence. Every other pending submission is then failed. The queue
// stays in that error state, and later calls return the same status.
//
// Returns OkStatus() when nothing more can run right now, the sticky error
// after a batch failure, or an error from signaling a completed submission's
// fence.
Status HostSubmissionQueue::ProcessBatches(ExecuteFn execute_fn) {
  IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatches");
  if (!permanent_error_.ok()) {
    // Sticky failure state.
    return permanent_error_;
  }

  // Repeatedly try to run things until we quiesce or are blocked.
  while (permanent_error_.ok() && !list_.empty()) {
    // NOTE: to support re-entrancy where |execute_fn| may modify the submission
    // list we need to always start from the beginning. If we wanted we could
    // track a list of ready submissions however that's a lot of bookkeeping and
    // the list is usually short.
    bool made_progress = false;
    for (auto* submission : list_) {
      auto& pending_batches = submission->pending_batches;
      if (pending_batches.empty()) {
        // Nothing left to run (e.g. enqueued with zero batches), so the
        // submission is complete. Previously such a submission was never
        // completed and its fence never signaled. Detach it first so it
        // leaves the list even if signaling the fence fails.
        auto completed = list_.take(submission);
        RETURN_IF_ERROR(CompleteSubmission(completed.get(), OkStatus()));
        made_progress = true;
        break;
      }
      for (size_t i = 0; i < pending_batches.size(); ++i) {
        auto& batch = pending_batches[i];
        if (!IsBatchReady(batch)) {
          // Try the next batch. If no batch in any submission is ready we
          // stop below and return to the caller.
          continue;
        }
        // Batch can run! Process now and remove it from the submission so we
        // don't try to run it again.
        auto batch_status = ProcessBatch(batch, execute_fn);
        pending_batches.erase(pending_batches.begin() + i);
        if (batch_status.ok()) {
          if (pending_batches.empty()) {
            // All work for this submission completed successfully. Detach it
            // before signaling so it is removed even if signaling fails.
            auto completed = list_.take(submission);
            RETURN_IF_ERROR(CompleteSubmission(completed.get(), OkStatus()));
          }
        } else {
          // Batch failed; set the sticky error so nothing else runs.
          permanent_error_ = batch_status;
          auto failed = list_.take(submission);
          // Best-effort (as in FailAllPending). Returning early here would
          // skip FailAllPending below and leave the remaining submissions'
          // fences unsignaled forever. The batch error is what gets reported.
          CompleteSubmission(failed.get(), batch_status).IgnoreError();
        }
        // Preserve submission order by restarting from the first submission.
        made_progress = true;
        break;
      }
      if (made_progress) break;
    }
    if (!made_progress) {
      // Every remaining batch is blocked on an unsignaled semaphore. Without
      // this break the outer loop would spin forever on the same state.
      break;
    }
  }

  if (!permanent_error_.ok()) {
    // If the sticky error got set while processing we need to abort all
    // remaining submissions (simulating a device loss).
    FailAllPending(permanent_error_);
    return permanent_error_;
  }
  return OkStatus();
}
// Runs one ready |batch| in three steps:
//   1. EndWaiting on each wait semaphore (completing the wait and resetting
//      it).
//   2. Passing the batch's command buffers to |execute_fn|.
//   3. EndSignaling on each signal semaphore so dependent batches can run.
// Returns the first error encountered and skips the remaining steps. Timeline
// semaphores return Unimplemented.
Status HostSubmissionQueue::ProcessBatch(const PendingBatch& batch,
                                         const ExecuteFn& execute_fn) {
  IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatch");
  // Complete the waits on all semaphores and reset them.
  for (const auto& semaphore_value : batch.wait_semaphores) {
    if (semaphore_value.index() != 0) {
      // TODO(b/140141417): implement timeline semaphores.
      return UnimplementedErrorBuilder(ABSL_LOC) << "Timeline semaphores NYI";
    }
    // static_cast applies any base->derived pointer adjustment, which
    // reinterpret_cast would skip.
    auto* binary_semaphore =
        static_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value));
    RETURN_IF_ERROR(binary_semaphore->EndWaiting());
  }

  // Let the caller handle execution of the command buffers.
  RETURN_IF_ERROR(execute_fn(batch.command_buffers));

  // Signal all semaphores to allow them to unblock waiters.
  for (const auto& semaphore_value : batch.signal_semaphores) {
    if (semaphore_value.index() != 0) {
      // TODO(b/140141417): implement timeline semaphores.
      return UnimplementedErrorBuilder(ABSL_LOC) << "Timeline semaphores NYI";
    }
    auto* binary_semaphore =
        static_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value));
    RETURN_IF_ERROR(binary_semaphore->EndSignaling());
  }
  return OkStatus();
}
// Finalizes |submission|.
//
// Any batches that have not run are discarded. Then the submission's fence is
// signaled with the value stored in |submission->fence.second| when |status|
// is OK, or failed with |status| otherwise. Returns any error from signaling
// or failing the fence.
Status HostSubmissionQueue::CompleteSubmission(Submission* submission,
                                               Status status) {
  IREE_TRACE_SCOPE0("HostSubmissionQueue::CompleteSubmission");
  // Unrun batches can be dropped: their semaphores will never be signaled, and
  // this queue is expected to be the only thing relying on them.
  submission->pending_batches.clear();

  auto& fence_value = submission->fence;
  auto* host_fence = static_cast<HostFence*>(fence_value.first);
  if (!status.ok()) {
    RETURN_IF_ERROR(host_fence->Fail(std::move(status)));
    return OkStatus();
  }
  RETURN_IF_ERROR(host_fence->Signal(fence_value.second));
  return OkStatus();
}
// Drains the queue from the front, failing each submission's fence with
// |status|. Errors from failing individual fences are ignored so that every
// submission is still removed.
void HostSubmissionQueue::FailAllPending(Status status) {
  IREE_TRACE_SCOPE0("HostSubmissionQueue::FailAllPending");
  while (!list_.empty()) {
    // Detach the head; it is destroyed when |doomed| leaves scope at the end
    // of this iteration.
    auto doomed = list_.take(list_.front());
    CompleteSubmission(doomed.get(), status).IgnoreError();
  }
}
// Flags the queue as shutting down. Enqueue checks |has_shutdown_| and rejects
// new submissions once it is set. Submissions already queued are not touched
// here.
void HostSubmissionQueue::SignalShutdown() {
  IREE_TRACE_SCOPE0("HostSubmissionQueue::SignalShutdown");
  // From now on, new work is refused.
  has_shutdown_ = true;
}
} // namespace hal
} // namespace iree