blob: 9565e48bcbe451ce244bc9751fa2baaa455f70ad [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/hal/local/local_executable.h"
#include "iree/hal/local/executable_environment.h"
// Sets up the shared base portion of a local executable.
//
// |target_pipeline_layouts| is caller-provided storage for
// |pipeline_layout_count| pointers. It is recorded on the executable, and each
// slot receives the matching entry of |source_pipeline_layouts| along with a
// retained reference. Those references are dropped again by
// iree_hal_local_executable_deinitialize.
//
// |dispatch_attrs| starts out NULL; the parent type may fill it in afterwards.
// The executable environment is initialized with |host_allocator| and has no
// imports assigned.
void iree_hal_local_executable_initialize(
    const iree_hal_local_executable_vtable_t* vtable,
    iree_host_size_t pipeline_layout_count,
    iree_hal_pipeline_layout_t* const* source_pipeline_layouts,
    iree_hal_pipeline_layout_t** target_pipeline_layouts,
    iree_allocator_t host_allocator,
    iree_hal_local_executable_t* out_base_executable) {
  iree_hal_local_executable_t* base = out_base_executable;
  iree_hal_resource_initialize(vtable, &base->resource);

  base->host_allocator = host_allocator;
  base->pipeline_layout_count = pipeline_layout_count;
  base->pipeline_layouts = target_pipeline_layouts;

  // Copy every layout pointer into the executable-owned slots and take a
  // reference so the layouts outlive the caller's array.
  for (iree_host_size_t index = 0; index < pipeline_layout_count; ++index) {
    iree_hal_pipeline_layout_t* layout = source_pipeline_layouts[index];
    target_pipeline_layouts[index] = layout;
    iree_hal_pipeline_layout_retain(layout);
  }

  // Optional per-function attributes; populated later by the parent type.
  base->dispatch_attrs = NULL;

  // Start with a default environment that has no imports bound.
  iree_hal_executable_environment_initialize(host_allocator,
                                             &base->environment);
}
// Drops the pipeline layout references taken by
// iree_hal_local_executable_initialize, in array order.
// The |pipeline_layouts| storage itself is not freed here.
// NOTE(review): presumably that storage is owned by the parent type; confirm.
void iree_hal_local_executable_deinitialize(
    iree_hal_local_executable_t* base_executable) {
  iree_host_size_t index = 0;
  while (index < base_executable->pipeline_layout_count) {
    iree_hal_pipeline_layout_release(base_executable->pipeline_layouts[index]);
    ++index;
  }
}
// Reinterprets a generic HAL executable as a local executable.
// No runtime type check is performed; the caller must only pass executables
// that were created as iree_hal_local_executable_t.
// NOTE(review): this relies on iree_hal_local_executable_t starting with a
// header layout-compatible with iree_hal_executable_t; confirm in the struct
// definition.
iree_hal_local_executable_t* iree_hal_local_executable_cast(
    iree_hal_executable_t* base_value) {
  iree_hal_local_executable_t* local_executable =
      (iree_hal_local_executable_t*)base_value;
  return local_executable;
}
// Runs the entry point at |ordinal| for one workgroup.
// All arguments are forwarded unchanged to the implementation's
// |issue_call| vtable slot. The implementation's status is returned as-is.
iree_status_t iree_hal_local_executable_issue_call(
    iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
    const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
    const iree_hal_executable_workgroup_state_v0_t* workgroup_state,
    uint32_t worker_id) {
  IREE_ASSERT_ARGUMENT(executable);
  IREE_ASSERT_ARGUMENT(dispatch_state);
  IREE_ASSERT_ARGUMENT(workgroup_state);
  // The resource header stores the vtable in its generic form; view it as the
  // local-executable vtable to reach issue_call.
  const iree_hal_local_executable_vtable_t* local_vtable =
      (const iree_hal_local_executable_vtable_t*)executable->resource.vtable;
  return local_vtable->issue_call(executable, ordinal, dispatch_state,
                                  workgroup_state, worker_id);
}
// Executes every workgroup of a dispatch serially on the calling thread.
//
// Workgroups are visited with z as the outermost loop and x as the innermost
// (x varies fastest). Each workgroup is issued through
// iree_hal_local_executable_issue_call with worker_id 0. All calls share one
// workgroup state that carries |processor_id| and |local_memory|.
//
// Returns OK when every workgroup succeeds. When a workgroup fails, iteration
// stops immediately and that failure status is returned; no further
// workgroups are issued.
iree_status_t iree_hal_local_executable_issue_dispatch_inline(
    iree_hal_local_executable_t* executable, iree_host_size_t ordinal,
    const iree_hal_executable_dispatch_state_v0_t* dispatch_state,
    uint32_t processor_id, iree_byte_span_t local_memory) {
  IREE_TRACE_ZONE_BEGIN(z0);
  // TODO(benvanik): annotate with executable name to calculate total time.

  const uint32_t workgroup_count_x = dispatch_state->workgroup_count_x;
  const uint32_t workgroup_count_y = dispatch_state->workgroup_count_y;
  const uint32_t workgroup_count_z = dispatch_state->workgroup_count_z;

#if IREE_HAL_VERBOSE_TRACING_ENABLE
  // TODO(benvanik): tracing.h helper that speeds this up; too slow.
  IREE_TRACE({
    // "%ux%ux%u" can produce up to 3 * 10 digits + 2 separators = 32 chars,
    // plus the NUL terminator. The buffer is sized so the text is never
    // truncated.
    char xyz_string[64];
    int xyz_string_length =
        snprintf(xyz_string, IREE_ARRAYSIZE(xyz_string), "%ux%ux%u",
                 workgroup_count_x, workgroup_count_y, workgroup_count_z);
    // snprintf returns a negative value on encoding errors; never pass that
    // through as a length.
    if (xyz_string_length > 0) {
      IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(z0, xyz_string,
                                              xyz_string_length);
    }
  });
#endif  // IREE_HAL_VERBOSE_TRACING_ENABLE

  iree_status_t status = iree_ok_status();
  iree_alignas(64) iree_hal_executable_workgroup_state_v0_t workgroup_state = {
      .workgroup_id_x = 0,
      .workgroup_id_y = 0,
      .workgroup_id_z = 0,
      .processor_id = processor_id,
      .local_memory = local_memory.data,
      .local_memory_size = (size_t)local_memory.data_length,
  };

  // Every loop level checks |status| so that a failure in the innermost call
  // ends the whole traversal. A plain `break` would only leave the x loop. The
  // y/z loops would then issue more workgroups and overwrite (and potentially
  // leak) the failing status, possibly replacing it with OK.
  for (uint32_t z = 0; z < workgroup_count_z && iree_status_is_ok(status);
       ++z) {
    workgroup_state.workgroup_id_z = z;
    for (uint32_t y = 0; y < workgroup_count_y && iree_status_is_ok(status);
         ++y) {
      workgroup_state.workgroup_id_y = y;
      for (uint32_t x = 0; x < workgroup_count_x && iree_status_is_ok(status);
           ++x) {
        workgroup_state.workgroup_id_x = x;
        status = iree_hal_local_executable_issue_call(
            executable, ordinal, dispatch_state, &workgroup_state,
            /*worker_id=*/0);
      }
    }
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}