blob: 58bb5ff9c1b53cfcbe7a5d62352bfedf9feb8bf3 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/webgpu/bind_group_cache.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include "experimental/webgpu/buffer.h"
#include "iree/base/api.h"
#include "iree/base/internal/math.h"
// Initializes |out_cache| for use with |device|.
// All cache slots start empty; callers must pair with
// iree_hal_webgpu_bind_group_cache_deinitialize.
void iree_hal_webgpu_bind_group_cache_initialize(
    WGPUDevice device, iree_hal_webgpu_bind_group_cache_t* out_cache) {
  IREE_ASSERT_ARGUMENT(out_cache);
  IREE_TRACE_ZONE_BEGIN(z0);
  out_cache->device = device;
  out_cache->entry_count = IREE_HAL_WEBGPU_BIND_GROUP_CACHE_CAPACITY;
  // Zero all entries so trim/acquire observe NULL handles instead of
  // indeterminate memory: both test entry->handle before dropping or evicting,
  // and without this a non-zeroed struct would lead to drops of garbage
  // handles and bogus cache hits.
  memset(out_cache->entries, 0, sizeof(out_cache->entries));
  IREE_TRACE_ZONE_END(z0);
}
// Releases all bind groups held by |cache| and marks it unusable.
// Safe to call on a cache that has already been trimmed.
void iree_hal_webgpu_bind_group_cache_deinitialize(
    iree_hal_webgpu_bind_group_cache_t* cache) {
  IREE_ASSERT_ARGUMENT(cache);
  IREE_TRACE_ZONE_BEGIN(z0);
  // Dropping everything on deinit is exactly what trim does; reuse it and
  // then clamp the capacity to zero so no further entries can be acquired.
  iree_hal_webgpu_bind_group_cache_trim(cache);
  cache->entry_count = 0;
  IREE_TRACE_ZONE_END(z0);
}
// Drops every live bind group in |cache| and resets all slots to empty.
// The cache remains usable afterwards (capacity is unchanged).
void iree_hal_webgpu_bind_group_cache_trim(
    iree_hal_webgpu_bind_group_cache_t* cache) {
  IREE_ASSERT_ARGUMENT(cache);
  IREE_TRACE_ZONE_BEGIN(z0);
  for (iree_host_size_t slot = 0; slot < cache->entry_count; ++slot) {
    WGPUBindGroup handle = cache->entries[slot].handle;
    if (handle) {
      iree_wgpuBindGroupDrop(handle);
    }
  }
  // Clear all slot metadata in one pass so subsequent acquires see only
  // empty entries (NULL handles, zero masks).
  memset(cache->entries, 0, sizeof(cache->entries));
  IREE_TRACE_ZONE_END(z0);
}
// Returns a bind group matching |group_layout| + |bindings|/|binding_mask|,
// reusing a cached one on a hit or creating (and caching) a new one on a
// miss. On a miss with no free slot an existing entry is evicted. The
// returned handle is owned by the cache and is only valid until the entry is
// evicted or the cache is trimmed/deinitialized.
WGPUBindGroup iree_hal_webgpu_bind_group_cache_acquire(
    iree_hal_webgpu_bind_group_cache_t* cache, WGPUBindGroupLayout group_layout,
    const iree_hal_webgpu_bind_group_binding_t* bindings,
    iree_hal_webgpu_binding_mask_t binding_mask) {
  IREE_ASSERT_ARGUMENT(cache);
  IREE_ASSERT_ARGUMENT(bindings);
  IREE_TRACE_ZONE_BEGIN(z0);
  // Defensive: with entry_count == 0 (e.g. use after deinitialize) the
  // insertion_slot computation below would underflow (iree_host_size_t is
  // unsigned) and the eviction write would index far out of bounds.
  if (cache->entry_count == 0) {
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  // This is not a good algorithm :)
  // We should probably have a split index and a mechanism to partition it such
  // that lookups don't need to perform a full scan. This is cheaper than the
  // cost of creating a new bind group per dispatch (no need to call out to
  // WebGPU, allocate new objects, track those new objects lifetimes, etc) but
  // not cheap. Ideally we'd be relying on this for 4-5 bind groups per command
  // buffer at which point it doesn't matter but the compiler still needs to
  // improve a bit there.
  // Scan the cache for entries with a matching group layout and binding mask.
  // These should be the same today but in the future we may want to allow for
  // subsetting as defined by bind group compatibility.
  // insertion_slot tracks the lowest-index empty slot seen during the scan;
  // it defaults to the last slot so that on a full cache the final entry is
  // the one evicted.
  iree_host_size_t insertion_slot = cache->entry_count - 1;
  for (iree_host_size_t i = 0; i < cache->entry_count; ++i) {
    iree_hal_webgpu_bind_group_cache_entry_t* entry = &cache->entries[i];
    if (!entry->handle) {
      insertion_slot = iree_min(insertion_slot, i);
      continue;
    }
    // Cheap rejections first: layout and mask must match exactly before we
    // pay for the full binding comparison.
    if (entry->group_layout != group_layout) continue;
    if (entry->binding_mask != binding_mask) continue;
    // Found a potential match. Do a full comparison of the bindings.
    // TODO(benvanik): we only really need to compare the bindings that are
    // set in the mask, however memcmp over a few hundred bytes is usually
    // faster than what we'd have to do for that comparison.
    // NOTE(review): this assumes unused binding slots are deterministically
    // zeroed by all callers; stale bytes in masked-off slots would cause
    // spurious misses (never false hits, as the mask already matched).
    if (memcmp(bindings, entry->bindings, sizeof(entry->bindings)) == 0) {
      // Same exact bindings - cache hit!
      // TODO(benvanik): a real LRU that rearranges this to the front/back.
      IREE_TRACE_ZONE_END(z0);
      return entry->handle;
    }
  }
  // Evict an existing entry to store this new one or use the last unused slot.
  iree_hal_webgpu_bind_group_cache_entry_t* entry =
      &cache->entries[insertion_slot];
  if (entry->handle) {
    IREE_TRACE_ZONE_APPEND_TEXT(z0, "evict");
    iree_wgpuBindGroupDrop(entry->handle);
  } else {
    IREE_TRACE_ZONE_APPEND_TEXT(z0, "miss");
  }
  // Record the key (layout + mask + full binding table) so future acquires
  // can hit this slot.
  entry->group_layout = group_layout;
  entry->binding_mask = binding_mask;
  memcpy(entry->bindings, bindings, sizeof(entry->bindings));
  // NOTE: we could change this to do bit scans over the binding_mask but I
  // haven't checked to see how expensive those are in WebAssembly. For now we
  // do a few more loop iterations with the assumption that doing a bit scan
  // may require hundreds of more instructions.
  uint32_t binding_count = 0;
  WGPUBindGroupEntry entries[IREE_HAL_WEBGPU_MAX_DESCRIPTOR_SET_BINDING_COUNT];
  for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(entries); ++i) {
    // NOTE(review): `1u << i` assumes the max binding count is <= 32 (the
    // shift is UB at i >= 32) — confirm against the constant's definition.
    if (!(binding_mask & (1u << i))) continue;
    entries[binding_count] = (WGPUBindGroupEntry){
        .nextInChain = NULL,
        // NOTE(review): binding indices are compacted to be dense (0..count-1)
        // rather than using the mask bit position |i|; this assumes the group
        // layout was created with the same compaction — confirm against the
        // layout creation code.
        .binding = binding_count,
        .buffer = bindings[i].buffer,
        .offset = (uint64_t)bindings[i].offset,
        .size = bindings[i].length,
    };
    ++binding_count;
  }
  const WGPUBindGroupDescriptor descriptor = {
      .nextInChain = NULL,
      .label = NULL,
      .layout = group_layout,
      .entryCount = binding_count,
      .entries = entries,
  };
  entry->handle = wgpuDeviceCreateBindGroup(cache->device, &descriptor);
  IREE_TRACE_ZONE_END(z0);
  return entry->handle;
}